In [None]:
import os
import glob
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from scipy.stats import mannwhitneyu
from itertools import combinations

# 1. Import the analysis function
from scripts.Rim_enrichment import analyze_rim_foci_enrichment

In [None]:
# 2. Configuration
# Root folder of the raw data: every immediate sub-directory is one group
# and is searched for .nd2 files by the batch loop.
main_dir = "path/to/your/input_data"
# Folder holding one sub-folder per image, named "<group>_<file stem>",
# from which the label/mask files are read.
coloc_results_dir = "FBL_coloc_results_HCT116"
# Destination for the combined tables and figures.
output_dir = "Rim_Enrichment_Results_FBL"
# Name of the analysed component; it is embedded in the mask / intensity
# file names, the score column names and the output file names.
target_component = "FBL"

# Per-file result tables, filled by the batch loop and concatenated afterwards.
all_results = []

# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting keeps the processing order, the x-axis order of the plots and the
# "g1 vs g2" pairing labels identical from run to run.
groups = sorted(
    entry
    for entry in os.listdir(main_dir)
    if os.path.isdir(os.path.join(main_dir, entry))
)

# 3. Batch Processing Loop
# For every .nd2 file of every group: locate the label/mask files of that
# image, run the rim enrichment analysis and collect the returned table.
for group in tqdm(groups, desc="Groups"):
    group_path = os.path.join(main_dir, group)
    # glob order is arbitrary; sort so rows are appended in a reproducible order.
    nd2_files = sorted(glob.glob(os.path.join(group_path, "*.nd2")))

    for nd2_file in tqdm(nd2_files, desc=f"Files in {group}", leave=False):
        file_stem = os.path.splitext(os.path.basename(nd2_file))[0]
        coloc_folder_path = os.path.join(coloc_results_dir, f"{group}_{file_stem}")

        # Paths to the masks written by the previous pipeline step
        nuclei_path   = os.path.join(coloc_folder_path, "nuclei_labels.tif")
        nucleoli_path = os.path.join(coloc_folder_path, "nucleoli_labels.tif")
        foci_path     = os.path.join(coloc_folder_path, f"{target_component}_mask.tif")
        raw_target    = os.path.join(coloc_folder_path, f"raw_{target_component}.png")

        # The nuclei labels and the foci mask are required before the analysis
        # is attempted. Report what is missing instead of skipping silently, so
        # an empty final result can be traced back to the absent inputs.
        missing_inputs = [
            path for path in (nuclei_path, foci_path) if not os.path.exists(path)
        ]
        if missing_inputs:
            print(f"Skipping {group}/{file_stem}: missing {', '.join(missing_inputs)}")
            continue

        # Best effort: a failure on one file is reported and must not abort
        # the rest of the batch.
        try:
            df, _ = analyze_rim_foci_enrichment(
                nd2_file=nd2_file,
                nuclei_labels_path=nuclei_path,
                nucleoli_labels_path=nucleoli_path,
                foci_mask_path=foci_path,
                target_intensity_path=raw_target,
                target_name=target_component,
                output_dir=output_dir,
                group_name=group,
                save_visuals=True
            )

            if not df.empty:
                # Tag every row with its origin so the tables can be pooled.
                df["Group"] = group
                df["File"] = file_stem
                all_results.append(df)
        except Exception as e:
            print(f"Error analyzing rim for {file_stem}: {e}")

# 4. Statistics and Plotting

def _significance_stars(p_value):
    """Return the star notation for a p-value ("ns" when it is not below 0.05)."""
    for threshold, stars in ((0.0001, "****"), (0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < threshold:
            return stars
    return "ns"


if all_results:
    final_df = pd.concat(all_results, ignore_index=True)

    # Make sure the destination exists before anything is written into it.
    os.makedirs(output_dir, exist_ok=True)

    # Save Combined Data
    final_df.to_csv(os.path.join(output_dir, f"All_Groups_{target_component}_Rim_Enrichment.csv"), index=False)
    print(f"Saved combined data with {len(final_df)} cells.")

    # -------------------------------------------------------------------------
    # Statistical Analysis (pairwise two-sided Mann-Whitney U tests)
    # -------------------------------------------------------------------------
    print("\n--- Running Mann-Whitney U Tests ---")

    # (column name, panel title) of every metric that is tested and plotted
    metrics = [
        ("prop_foci_area_in_rim", "Area Fraction in Rim (Crowding)"),
        (f"score_area_{target_component}", "Area Enrichment Score"),
        ("prop_intensity_in_rim", "Intensity Fraction in Rim (Mass)"),
        (f"score_intensity_{target_component}", "Intensity Enrichment Score (Concentration)")
    ]

    stats_list = []
    unique_groups = final_df['Group'].unique()

    # Every unordered pair of groups (e.g., Control vs Treated)
    group_pairs = list(combinations(unique_groups, 2))

    for metric, metric_name in metrics:
        if metric not in final_df.columns:
            continue

        # Extract each group's values once instead of re-filtering per pair.
        samples = {
            g: final_df.loc[final_df['Group'] == g, metric].dropna()
            for g in unique_groups
        }

        for g1, g2 in group_pairs:
            data_g1, data_g2 = samples[g1], samples[g2]

            # A comparison needs at least one observation on each side.
            if data_g1.empty or data_g2.empty:
                continue

            stat, p_val = mannwhitneyu(data_g1, data_g2, alternative='two-sided')
            sig = _significance_stars(p_val)

            stats_list.append({
                "Metric": metric,
                "Comparison": f"{g1} vs {g2}",
                "Group1": g1,
                "Group2": g2,
                "p-value": p_val,
                "Significance": sig,
                "U-stat": stat
            })

            print(f"[{metric_name}] {g1} vs {g2}: p={p_val:.4e} ({sig})")

    # Save Stats to CSV
    if stats_list:
        stats_df = pd.DataFrame(stats_list)

        # "p-value" and "Significance" are the raw, uncorrected results. With
        # more than two groups several pairs are tested per metric, so also
        # report a Bonferroni-adjusted value: p multiplied by the number of
        # comparisons run for that metric, capped at 1.
        tests_per_metric = stats_df.groupby("Metric")["p-value"].transform("size")
        stats_df["p-adj (Bonferroni)"] = (stats_df["p-value"] * tests_per_metric).clip(upper=1.0)

        stats_path = os.path.join(output_dir, f"Stats_MannWhitney_{target_component}.csv")
        stats_df.to_csv(stats_path, index=False)
        print(f"\nStatistical results saved to: {stats_path}")

    # -------------------------------------------------------------------------
    # Plotting (Visuals)
    # -------------------------------------------------------------------------
    sns.set_style("whitegrid")

    # 2x2 grid, one panel per metric
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    axes = axes.ravel()

    legend_drawn = False
    for ax, (metric, title) in zip(axes, metrics):

        if metric not in final_df.columns:
            print(f"Skipping plot for {metric} (not in dataframe)")
            # Hide the panel instead of leaving an empty frame in the figure.
            ax.set_visible(False)
            continue

        # Passing `palette` without `hue` is deprecated in recent seaborn, so
        # colour by the x variable explicitly; dodge=False keeps one
        # full-width violin per group.
        sns.violinplot(
            data=final_df,
            x="Group",
            y=metric,
            hue="Group",
            dodge=False,
            ax=ax,
            inner="box",
            palette="muted"
        )

        sns.stripplot(
            data=final_df,
            x="Group",
            y=metric,
            ax=ax,
            color="black",
            alpha=0.2,
            size=2,
            jitter=True
        )

        # The hue legend only repeats the x tick labels; drop it if one was added.
        hue_legend = ax.get_legend()
        if hue_legend is not None:
            hue_legend.remove()

        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.set_xlabel("")
        # Restyle the existing tick labels rather than re-setting their text.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Reference line for the enrichment scores
        if "score" in metric:
            ref_line = ax.axhline(1.0, color='red', linestyle='--', alpha=0.5, label="No Enrichment (1.0)")
            # Explain the line once, on the first score panel that is drawn.
            if not legend_drawn:
                ax.legend(handles=[ref_line])
                legend_drawn = True

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"Violin_Plots_{target_component}_Rim_Enrichment_FULL.png"), dpi=300)
    plt.show()

else:
    print("No results found. Check paths and ensure get_CoLoc was run successfully.")