## DUSP1 Dataframe Concatenation and Replica Check
- Concatenate experimental dataframes based on specific conditions:
    - **100nM 3hr Time-Sweep (TS):** Data collected over a 3-hour time period with a fixed concentration of 100nM.
    - **75min Concentration-Sweep:** Data collected at a single 75-minute time point across varying concentrations.
    - **3hr Time-Concentration Sweep (TCS):** Data collected over a 3-hour period with varying concentrations.
    - **Triptolide (TPL):** Data collected under conditions involving triptolide treatment.
- Perform a replica check to ensure data consistency and identify any discrepancies across experimental replicates.
- Document and visualize the concatenated data for further analysis.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os
import sys
import datetime

import h5py
import dask.array as da

# src_path = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
# print(src_path)
# sys.path.append(src_path)

# from src.Analysis_DUSP1_v2 import DUSP1DisplayManager

## DUSP1 Experiments

In [None]:
# Directory containing the merged per-experiment CSV exports
df_directory = '/Users/ericron/Desktop/AngelFISH/Publications/Ron_2024/Classification'

# One list of DataFrames per table type; each list is concatenated below
spots_total = []
clusters_total = []
cellprops_total = []
cellresults_total = []

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes the
# concatenation (and therefore the row order of every combined table) reproducible.
all_files = sorted(os.listdir(df_directory))

# Route each file to its table by filename suffix
for file in all_files:
    filepath = os.path.join(df_directory, file)

    if file.endswith('merged_spots_df_MG3_Abs4_Apr24.csv'):
        spots_total.append(pd.read_csv(filepath))
    elif file.endswith('merged_clusters_df_MG3_Abs4_Apr24.csv'):
        clusters_total.append(pd.read_csv(filepath))
    elif file.endswith('merged_cellprops_df_MG3_Abs4_Apr24.csv'):
        cellprops_total.append(pd.read_csv(filepath))
    elif file.endswith('cell_level_results_MG3_Abs4_Apr24.csv'):
        cellresults_total.append(pd.read_csv(filepath))

# Fail with an actionable message instead of pd.concat's generic
# "No objects to concatenate" when a table type matched no files.
for table_name, frames in [('spots', spots_total),
                           ('clusters', clusters_total),
                           ('cellprops', cellprops_total),
                           ('cell_level_results', cellresults_total)]:
    if not frames:
        raise FileNotFoundError(f"No {table_name} CSV files found in {df_directory}")

# Concatenate into single DataFrames
spots_total = pd.concat(spots_total, ignore_index=True)
clusters_total = pd.concat(clusters_total, ignore_index=True)
cellprops_total = pd.concat(cellprops_total, ignore_index=True)
cellresults_total = pd.concat(cellresults_total, ignore_index=True)

In [None]:
# Work on a copy so the raw concatenated cell-level results stay untouched
DUSP1_data = cellresults_total.copy()

# Per-experiment subsets, selected by the replica letters belonging to each experiment.
# Experiment 1: 100 nM Dex time sweep -- 11 post-treatment timepoints (+ shared 0-min baseline)
df_expt1 = DUSP1_data[DUSP1_data['replica'].isin(['D', 'E', 'F', 'M', 'N'])]
expt1_timepoints = [10, 20, 30, 40, 50, 60, 75, 90, 120, 150, 180]
expt1_concs = [100]

# Experiment 2: 75 min concentration sweep over 8 concentrations (nM)
df_expt2 = DUSP1_data[DUSP1_data['replica'].isin(['G', 'H', 'I'])]
expt2_timepoints = [75]
expt2_concs = [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]

# Experiment 3: 0.3, 1 and 10 nM Dex time sweeps over 6 timepoints
df_expt3 = DUSP1_data[DUSP1_data['replica'].isin(['J', 'K', 'L'])]
expt3_timepoints = [30, 50, 75, 90, 120, 180]
expt3_concs = [0.3, 1, 10]

# Step 1: average each count within every (dex_conc, time, replica) group
replica_means = (
    DUSP1_data
    .groupby(['dex_conc', 'time', 'replica'])[['nuc_MG_count', 'cyto_MG_count', 'MG_count']]
    .mean()
    .reset_index()
)

# Step 2: spread of the replica means per (dex_conc, time) -> replica-to-replica variability
summary_stats = (
    replica_means
    .groupby(['dex_conc', 'time'])
    .agg(
        mean_nuc_count=('nuc_MG_count', 'mean'),
        std_nuc_count=('nuc_MG_count', 'std'),
        mean_cyto_count=('cyto_MG_count', 'mean'),
        std_cyto_count=('cyto_MG_count', 'std'),
        mean_MG_count=('MG_count', 'mean'),
        std_MG_count=('MG_count', 'std'),
    )
    .reset_index()
)

# Step 3: mean/std over all rows pooled across replicas per (dex_conc, time)
overall_stats = (
    DUSP1_data
    .groupby(['dex_conc', 'time'])
    .agg(
        overall_mean_nuc=('nuc_MG_count', 'mean'),
        overall_std_nuc=('nuc_MG_count', 'std'),
        overall_mean_cyto=('cyto_MG_count', 'mean'),
        overall_std_cyto=('cyto_MG_count', 'std'),
        overall_mean_MG=('MG_count', 'mean'),
        overall_std_MG=('MG_count', 'std'),
    )
    .reset_index()
)

# Rows at time 0, used as the shared baseline point in the plots.
# NOTE(review): this filters on time only; it presumably contains only dex_conc == 0 rows -- confirm.
zero_min_summary = summary_stats[summary_stats['time'] == 0]
zero_min_overall = overall_stats[overall_stats['time'] == 0]

# Global plot style shared by every figure below
sns.set_theme(style="ticks", palette="colorblind", context="poster", font='times new roman')

# Two colorblind-safe colours: index 0 = nuclear, index 1 = cytoplasmic
colors_nuc_cyto = sns.color_palette("colorblind", 2)

# Time-sweep experiments plotted against time. Experiment 2 (75 min concentration
# sweep) is plotted separately on a log-concentration axis in a later cell.
experiments = {
    "Experiment 1: 100 nM Time Sweep": (expt1_concs, expt1_timepoints),
    "Experiment 3: 0.3, 1, 10 nM Time Sweep": (expt3_concs, expt3_timepoints)
}

# One entry per compartment:
# (replica-mean col, replica-std col, overall-mean col, overall-std col, colour,
#  errorbar legend label, shaded-band legend label)
series_specs = [
    ('mean_nuc_count', 'std_nuc_count', 'overall_mean_nuc', 'overall_std_nuc',
     colors_nuc_cyto[0], 'Nuclear mRNA Count Replicas', 'Total Data Spread - Nuclear'),
    ('mean_cyto_count', 'std_cyto_count', 'overall_mean_cyto', 'overall_std_cyto',
     colors_nuc_cyto[1], 'Cytoplasmic mRNA Count Replicas', 'Total Data Spread - Cytoplasmic'),
    ('mean_MG_count', 'std_MG_count', 'overall_mean_MG', 'overall_std_MG',
     'black', 'Total mRNA Count Replicas', 'Total Data Spread - Total'),
]

for expt_name, (concs, timepoints) in experiments.items():
    for conc in concs:
        # Statistics for this concentration at the experiment's timepoints
        subset_summary = summary_stats[(summary_stats['dex_conc'] == conc) & (summary_stats['time'].isin(timepoints))]
        subset_overall = overall_stats[(overall_stats['dex_conc'] == conc) & (overall_stats['time'].isin(timepoints))]

        # Prepend the shared 0-min baseline (it is not in any timepoint list)
        if 0 not in subset_summary['time'].values:
            subset_summary = pd.concat([zero_min_summary, subset_summary], ignore_index=True)
        if 0 not in subset_overall['time'].values:
            subset_overall = pd.concat([zero_min_overall, subset_overall], ignore_index=True)

        plt.figure(figsize=(10, 5))

        for (mean_col, std_col, overall_mean_col, overall_std_col,
             color, bar_label, band_label) in series_specs:
            # Replica-to-replica variability: mean ± std of the per-replica means
            plt.errorbar(subset_summary['time'], subset_summary[mean_col],
                         yerr=subset_summary[std_col], fmt='-o', color=color, capsize=5,
                         label=bar_label)

            # Pooled spread: mean ± std over all rows of DUSP1_data (presumably individual cells)
            plt.fill_between(subset_overall['time'],
                             subset_overall[overall_mean_col] - subset_overall[overall_std_col],
                             subset_overall[overall_mean_col] + subset_overall[overall_std_col],
                             color=color, alpha=0.2, label=band_label)

        # One x-tick per measured timepoint
        plt.xticks(subset_summary['time'], rotation=45)

        plt.title(f"{expt_name} - {conc} nM Dex", fontsize=18, fontweight='bold')
        plt.xlabel('Time (min)', fontsize=14)
        plt.ylabel('mRNA Spot Count', fontsize=14)
        plt.grid(True)
        plt.legend(loc='upper left', fontsize=12, frameon=False, bbox_to_anchor=(1, 1))

        plt.show()


In [None]:
# Same time sweeps as the overlaid figure, but with nuclear, cytoplasmic and total
# counts on three side-by-side panels. Experiment 2 is plotted in its own cell.
experiments = {
    "Experiment 1: 100 nM Time Sweep": (expt1_concs, expt1_timepoints),
    "Experiment 3: 0.3, 1, 10 nM Time Sweep": (expt3_concs, expt3_timepoints)
}

# Per-panel configuration:
# (title, colour, replica-mean col, replica-std col, overall-mean col, overall-std col, label)
panel_specs = [
    ('Nuclear', colors_nuc_cyto[0], 'mean_nuc_count', 'std_nuc_count',
     'overall_mean_nuc', 'overall_std_nuc', 'Nuclear mRNA Count Replicas'),
    ('Cytoplasmic', colors_nuc_cyto[1], 'mean_cyto_count', 'std_cyto_count',
     'overall_mean_cyto', 'overall_std_cyto', 'Cytoplasmic mRNA Count Replicas'),
    ('Total', 'black', 'mean_MG_count', 'std_MG_count',
     'overall_mean_MG', 'overall_std_MG', 'Total mRNA Count Replicas'),
]

for expt_name, (concs, timepoints) in experiments.items():
    for conc in concs:
        in_summary = (summary_stats['dex_conc'] == conc) & (summary_stats['time'].isin(timepoints))
        in_overall = (overall_stats['dex_conc'] == conc) & (overall_stats['time'].isin(timepoints))
        subset_summary = summary_stats[in_summary]
        subset_overall = overall_stats[in_overall]

        # The shared 0-min baseline is prepended whenever it is missing
        if 0 not in subset_summary['time'].values:
            subset_summary = pd.concat([zero_min_summary, subset_summary], ignore_index=True)
        if 0 not in subset_overall['time'].values:
            subset_overall = pd.concat([zero_min_overall, subset_overall], ignore_index=True)

        fig, axes = plt.subplots(1, 3, figsize=(20, 5), sharey=False)

        for panel_idx, (ax, spec) in enumerate(zip(axes, panel_specs)):
            panel_title, color, mean_col, std_col, ov_mean_col, ov_std_col, series_label = spec

            # Replica means ± std as error bars
            ax.errorbar(subset_summary['time'], subset_summary[mean_col],
                        yerr=subset_summary[std_col], fmt='-o', color=color, capsize=5,
                        label=series_label)

            # Pooled mean ± std as a shaded band
            lower = subset_overall[ov_mean_col] - subset_overall[ov_std_col]
            upper = subset_overall[ov_mean_col] + subset_overall[ov_std_col]
            ax.fill_between(subset_overall['time'], lower, upper, color=color, alpha=0.2)

            ax.set_title(panel_title, fontsize=16)
            ax.set_xlabel('Time (min)')
            ax.grid(True)
            ax.set_xticks(subset_summary['time'])
            ax.tick_params(axis='x', rotation=45)

            # Only the leftmost panel carries the y-axis label
            if panel_idx == 0:
                ax.set_ylabel('mRNA Count')

        # Experiment 3 figures show only the concentration as the title
        suptitle_text = (f"{conc} nM Dex" if "Experiment 3" in expt_name
                         else f"{expt_name} - {conc} nM Dex")
        fig.suptitle(suptitle_text, fontsize=18, fontweight='bold')

        # Reserve the top 5% for the suptitle and the right 10% of the figure
        # (no legend is currently drawn there)
        plt.tight_layout(rect=[0, 0, 0.9, 0.95])
        plt.show()

In [None]:
import matplotlib.ticker as ticker
import numpy as np

# Experiment 2 (75 min concentration sweep) is plotted against a log-scaled
# concentration axis, so it is handled separately from the time sweeps.
expt2_name = "Experiment 2: 75 min Concentration Sweep"
concs = expt2_concs
timepoints = expt2_timepoints

# 0 nM cannot be placed on a log axis, so it is drawn at a stand-in position.
# The stand-in must lie strictly below the lowest tested concentration
# (0.001 nM = 1 pM): placing it at 1e-3 would stack the '0' and '1 pM' ticks
# on top of each other and the gridline removal below would hide 1 pM too.
zero_conc_value = 1e-4

# Shared 0 nM baseline (time-0 rows with dex_conc == 0). It is used when a timepoint
# has no 0 nM row of its own, mirroring how the time-sweep plots prepend the 0-min point.
baseline_summary = zero_min_summary[zero_min_summary['dex_conc'] == 0]
baseline_overall = zero_min_overall[zero_min_overall['dex_conc'] == 0]

# Panel configuration (identical for every timepoint)
titles = ['Nuclear', 'Cytoplasmic', 'Total']
colors = [colors_nuc_cyto[0], colors_nuc_cyto[1], 'black']
means = ['mean_nuc_count', 'mean_cyto_count', 'mean_MG_count']
stds = ['std_nuc_count', 'std_cyto_count', 'std_MG_count']
overall_means = ['overall_mean_nuc', 'overall_mean_cyto', 'overall_mean_MG']
overall_stds = ['overall_std_nuc', 'overall_std_cyto', 'overall_std_MG']
labels = ['Nuclear mRNA Count Replicas', 'Cytoplasmic mRNA Count Replicas', 'Total mRNA Count Replicas']

# Tick positions (0 stand-in followed by the tested concentrations, in nM) and labels
all_concs_for_ticks = [zero_conc_value, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]
tick_labels = ['0', '1 pM', '10 pM', '100 pM', '1 nM', '10 nM', '100 nM', '1 µM', '10 µM']

# Concentrations to select: 0 nM control plus every tested concentration
plotted_concs = [0, *concs]

for timepoint in timepoints:
    subset_summary = summary_stats[(summary_stats['time'] == timepoint) & (summary_stats['dex_conc'].isin(plotted_concs))]
    subset_overall = overall_stats[(overall_stats['time'] == timepoint) & (overall_stats['dex_conc'].isin(plotted_concs))]

    # Fall back to the shared baseline when this timepoint has no 0 nM row
    if not (subset_summary['dex_conc'] == 0).any():
        subset_summary = pd.concat([baseline_summary, subset_summary], ignore_index=True)
    if not (subset_overall['dex_conc'] == 0).any():
        subset_overall = pd.concat([baseline_overall, subset_overall], ignore_index=True)

    # Move 0 nM onto its log-axis stand-in (assign returns a new frame, so no
    # chained-assignment warnings) and keep the points in concentration order.
    subset_summary = subset_summary.assign(
        dex_conc=np.where(subset_summary['dex_conc'] == 0, zero_conc_value, subset_summary['dex_conc'])
    ).sort_values('dex_conc', ignore_index=True)
    subset_overall = subset_overall.assign(
        dex_conc=np.where(subset_overall['dex_conc'] == 0, zero_conc_value, subset_overall['dex_conc'])
    ).sort_values('dex_conc', ignore_index=True)

    fig, axes = plt.subplots(1, 3, figsize=(20, 5), sharey=False)

    for i, ax in enumerate(axes):
        # Replica means ± std as error bars
        ax.errorbar(subset_summary['dex_conc'], subset_summary[means[i]],
                    yerr=subset_summary[stds[i]], fmt='-o', color=colors[i], capsize=5,
                    label=labels[i])

        # Pooled mean ± std as a shaded band
        ax.fill_between(subset_overall['dex_conc'],
                        subset_overall[overall_means[i]] - subset_overall[overall_stds[i]],
                        subset_overall[overall_means[i]] + subset_overall[overall_stds[i]],
                        color=colors[i], alpha=0.2)

        ax.set_title(titles[i], fontsize=16)
        ax.set_xlabel('Dex Concentration')
        ax.set_xscale('log')

        # Manual ticks/labels; the left limit leaves room for the 0 stand-in
        ax.set_xlim(left=zero_conc_value / 2, right=1.2e4)
        ax.set_xticks(all_concs_for_ticks)
        ax.set_xticklabels(tick_labels)

        # Only major gridlines
        ax.grid(True, which='major', axis='x', linestyle='--', alpha=0.6)

        # Hide the gridline at the fake 0 position (it is not a real concentration)
        for line in ax.get_xgridlines():
            if np.isclose(line.get_xdata()[0], zero_conc_value, rtol=1e-2):
                line.set_visible(False)

        ax.tick_params(axis='x', rotation=45)

        if i == 0:
            ax.set_ylabel('mRNA Count')

    fig.suptitle(f"{expt2_name} - {timepoint} min", fontsize=18, fontweight='bold')

    plt.tight_layout(rect=[0, 0, 0.9, 0.95])
    plt.show()

## Replica comparison

In [None]:
from scipy.stats import ks_2samp, wasserstein_distance, rankdata
from itertools import combinations
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
# Pairwise comparison function
def compute_pairwise_comparisons(df, metrics_to_compare):
    pairwise_results = []
    unique_conditions = df[['time', 'dex_conc']].drop_duplicates().sort_values(by=['time', 'dex_conc']).values
    
    for time_val, conc_val in unique_conditions:
        condition_df = df[(df['time'] == time_val) & (df['dex_conc'] == conc_val)]
        replicas_present = condition_df['replica'].unique()
        
        if len(replicas_present) < 2:
            continue  # skip if only 1 replica present
        
        for rep_a, rep_b in combinations(replicas_present, 2):
            df_a = condition_df[condition_df['replica'] == rep_a]
            df_b = condition_df[condition_df['replica'] == rep_b]
            
            for metric, column in metrics_to_compare:
                values_a = df_a[column].values
                values_b = df_b[column].values
                
                ks_stat, ks_pvalue = ks_2samp(values_a, values_b)
                wass_dist = wasserstein_distance(values_a, values_b)
                
                pairwise_results.append({
                    'Time_min': time_val,
                    'Dex_conc_nM': conc_val,
                    'Replica_A': rep_a,
                    'Replica_B': rep_b,
                    'Metric': metric,
                    'KS_pvalue': ks_pvalue,
                    'Wasserstein_distance': wass_dist,
                    'Num_cells_A': len(values_a),
                    'Num_cells_B': len(values_b)
                })
    
    return pd.DataFrame(pairwise_results)


# Heatmap plotting function
def plot_wasserstein_heatmap(pairwise_df, time_val, conc_val, metric):
    """Show a symmetric replica-by-replica heatmap of Wasserstein distances.

    Args:
        pairwise_df (pd.DataFrame): output of compute_pairwise_comparisons.
        time_val: value of 'Time_min' to plot.
        conc_val: value of 'Dex_conc_nM' to plot.
        metric (str): metric label, e.g. 'nuc'.

    Raises:
        ValueError: if no comparisons exist for the requested condition/metric.
    """
    subset = pairwise_df[(pairwise_df['Time_min'] == time_val) &
                         (pairwise_df['Dex_conc_nM'] == conc_val) &
                         (pairwise_df['Metric'] == metric)]

    if subset.empty:
        raise ValueError(
            f"No pairwise comparisons for Time_min={time_val}, Dex_conc_nM={conc_val}, Metric={metric!r}"
        )

    replicas = np.unique(np.concatenate([subset['Replica_A'].unique(), subset['Replica_B'].unique()]))
    replicas_sorted = sorted(replicas)
    position = {rep: idx for idx, rep in enumerate(replicas_sorted)}

    # Fill a plain numpy array and wrap it afterwards. Writing through
    # DataFrame.values can fail when pandas Copy-on-Write hands back a read-only view.
    dist_values = np.full((len(replicas_sorted), len(replicas_sorted)), np.nan)
    for rep_a, rep_b, dist in zip(subset['Replica_A'], subset['Replica_B'], subset['Wasserstein_distance']):
        i, j = position[rep_a], position[rep_b]
        dist_values[i, j] = dist
        dist_values[j, i] = dist
    np.fill_diagonal(dist_values, 0)  # a replica is at distance 0 from itself

    dist_matrix = pd.DataFrame(dist_values, index=replicas_sorted, columns=replicas_sorted)

    plt.figure(figsize=(6, 5))
    sns.heatmap(dist_matrix, annot=True, fmt=".1f", cmap='viridis', cbar_kws={'label': 'Wasserstein Distance'})
    plt.title(f"Wasserstein Distance Heatmap\nTime={time_val} min, Conc={conc_val} nM, Metric={metric}")
    plt.xlabel("Replica")
    plt.ylabel("Replica")
    plt.tight_layout()
    plt.show()

def generate_summary_table(pairwise_df, distance_threshold=40.0):
    """Summarise replica disagreement per (time, concentration, metric).

    Args:
        pairwise_df (pd.DataFrame): output of compute_pairwise_comparisons.
        distance_threshold (float): Wasserstein distance above which a replica
            pair counts as a "high distance" pair.

    Returns:
        pd.DataFrame with columns Time_min, Dex_conc_nM, Metric,
        Max_Wasserstein_distance, Replica_A_Max, Replica_B_Max (the pair with the
        largest distance), Num_high_distance_pairs and Total_pairs.

        Combinations whose distances are all NaN are omitted, since no maximum
        exists for them.
    """
    summary_rows = []
    unique_conditions = pairwise_df[['Time_min', 'Dex_conc_nM']].drop_duplicates().values

    for time_val, conc_val in unique_conditions:
        subset = pairwise_df[(pairwise_df['Time_min'] == time_val) &
                             (pairwise_df['Dex_conc_nM'] == conc_val)]
        for metric in subset['Metric'].unique():
            metric_subset = subset[subset['Metric'] == metric]
            distances = metric_subset['Wasserstein_distance']

            # idxmax has no answer for all-NaN (or empty) input; pandas either
            # returns NaN or raises there, so skip such groups explicitly.
            if distances.isna().all():
                continue

            max_dist_row = metric_subset.loc[distances.idxmax()]

            summary_rows.append({
                'Time_min': time_val,
                'Dex_conc_nM': conc_val,
                'Metric': metric,
                'Max_Wasserstein_distance': max_dist_row['Wasserstein_distance'],
                'Replica_A_Max': max_dist_row['Replica_A'],
                'Replica_B_Max': max_dist_row['Replica_B'],
                'Num_high_distance_pairs': (distances > distance_threshold).sum(),
                'Total_pairs': len(metric_subset)
            })

    return pd.DataFrame(summary_rows)


def plot_ecdf_per_replica(df, time_val, conc_val, metric_name, column_name):
    """Plot one empirical CDF per replica for a single (time, dex_conc) condition.

    Args:
        df (pd.DataFrame): per-cell table with 'time', 'dex_conc', 'replica' and ``column_name``.
        time_val: time (min) to select.
        conc_val: dex concentration (nM) to select.
        metric_name (str): label used for the x-axis and title.
        column_name (str): column whose values are plotted.
    """
    in_condition = (df['time'] == time_val) & (df['dex_conc'] == conc_val)
    condition_df = df.loc[in_condition]

    plt.figure(figsize=(7, 5))

    for rep in sorted(condition_df['replica'].unique()):
        xs = np.sort(condition_df.loc[condition_df['replica'] == rep, column_name].to_numpy())
        n_cells = len(xs)
        # Step height k/n at the k-th smallest value
        ys = np.arange(1, n_cells + 1) / n_cells
        plt.step(xs, ys, where='post', label=f'Replica {rep}')

    plt.xlabel(metric_name)
    plt.ylabel('ECDF')
    plt.title(f'ECDF per Replica\nTime={time_val} min, Conc={conc_val} nM, Metric={metric_name}')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

# Identify outlier cells based on replica disagreement
def identify_outlier_cells(df, pairwise_df, metrics_to_compare, distance_threshold=40.0, percentile_cutoff=5):
    """
    Identify individual outlier cells based on replica disagreement.

    A (time, conc, metric) combination is examined only if its largest pairwise
    Wasserstein distance reaches ``distance_threshold``. Within it, every cell is
    ranked against the pooled cells of all *other* replicas in that condition.
    It is flagged if its percentile rank falls below ``percentile_cutoff`` or
    above ``100 - percentile_cutoff``.

    The percentile rank matches ranking the cell's value together with the other
    replicas' values using average ranks for ties (``scipy.stats.rankdata``'s
    default): ``100 * (rank - 1) / n_other``, which equals
    ``100 * (n_less + n_equal / 2) / n_other``. It is evaluated for all cells at
    once with ``np.searchsorted`` instead of re-ranking the reference set per cell.

    NaN metric values are ignored: NaN cells are never flagged and NaNs are not
    part of the reference distribution. Combinations whose distances are all NaN
    are skipped.

    Args:
        df (pd.DataFrame): full dataframe of cells, must include 'unique_cell_id', 'replica', 'time', 'dex_conc', metric columns.
        pairwise_df (pd.DataFrame): output of compute_pairwise_comparisons function.
        metrics_to_compare (list): list of tuples, e.g. [('nuc', 'num_nuc_spots'), ...]
        distance_threshold (float): minimum Wasserstein distance required to trigger analysis of a (time, conc, metric).
        percentile_cutoff (float): percentile cutoff for outliers (default 5 → flags top/bottom 5%).

    Returns:
        pd.DataFrame: table of outlier cells with columns (present even when no cell is flagged):
            unique_cell_id, replica, time, dex_conc, metric, cell_value, percentile_rank
    """
    output_columns = ['unique_cell_id', 'replica', 'time', 'dex_conc', 'metric', 'cell_value', 'percentile_rank']
    outlier_cells = []
    unique_conditions = pairwise_df[['Time_min', 'Dex_conc_nM']].drop_duplicates().values

    for time_val, conc_val in unique_conditions:
        subset_pairwise = pairwise_df[(pairwise_df['Time_min'] == time_val) &
                                      (pairwise_df['Dex_conc_nM'] == conc_val)]
        condition_df = df[(df['time'] == time_val) & (df['dex_conc'] == conc_val)]
        replicas_present = condition_df['replica'].unique()

        for metric, column_name in metrics_to_compare:
            metric_subset = subset_pairwise[subset_pairwise['Metric'] == metric]

            if len(metric_subset) == 0:
                continue

            # Only conditions where some replica pair disagrees strongly are analysed.
            # Written as "not >=" so an all-NaN max (no valid distance) is skipped too.
            max_dist = metric_subset['Wasserstein_distance'].max()
            if not max_dist >= distance_threshold:
                continue

            for replica in replicas_present:
                in_replica = condition_df['replica'] == replica

                this_cells = condition_df.loc[in_replica, ['unique_cell_id', column_name]].dropna(subset=[column_name])
                reference = np.sort(condition_df.loc[~in_replica, column_name].dropna().to_numpy())
                n_reference = len(reference)

                if n_reference == 0 or this_cells.empty:
                    continue  # no other replica to compare against, or nothing to rank

                values = this_cells[column_name].to_numpy()
                cell_ids = this_cells['unique_cell_id'].to_numpy()

                # Average-rank percentile of each value within reference + [value]
                n_less = np.searchsorted(reference, values, side='left')
                n_less_or_equal = np.searchsorted(reference, values, side='right')
                percentile_ranks = 100.0 * (n_less + 0.5 * (n_less_or_equal - n_less)) / n_reference

                is_outlier = (percentile_ranks < percentile_cutoff) | (percentile_ranks > (100 - percentile_cutoff))

                for cell_id, value, percentile_rank in zip(cell_ids[is_outlier],
                                                           values[is_outlier],
                                                           percentile_ranks[is_outlier]):
                    outlier_cells.append({
                        'unique_cell_id': cell_id,
                        'replica': replica,
                        'time': time_val,
                        'dex_conc': conc_val,
                        'metric': metric,
                        'cell_value': value,
                        'percentile_rank': percentile_rank
                    })

    return pd.DataFrame(outlier_cells, columns=output_columns)

def compute_replica_stability(pairwise_df):
    """
    Compute replica stability scores: mean Wasserstein distance of each replica
    to other replicas, per (time, conc, metric).

    Args:
        pairwise_df (pd.DataFrame): output of compute_pairwise_comparisons. The
            concentration column may be named 'Dex_conc_nM' or the legacy
            misspelling 'Dex_conc_NM'.

    Returns:
        pd.DataFrame: stability scores with columns:
            Time_min, Dex_conc_nM, Metric, Replica, Mean_Wasserstein_distance, Num_pairs
        (the output always uses 'Dex_conc_nM').
    """
    stability_rows = []

    # Resolve the concentration column first, then select both key columns as a
    # DataFrame so each row unpacks into (time, conc).
    conc_col = 'Dex_conc_NM' if 'Dex_conc_NM' in pairwise_df.columns else 'Dex_conc_nM'
    unique_conditions = pairwise_df[['Time_min', conc_col]].drop_duplicates().values

    for time_val, conc_val in unique_conditions:
        subset = pairwise_df[(pairwise_df['Time_min'] == time_val) &
                             (pairwise_df[conc_col] == conc_val)]

        for metric in subset['Metric'].unique():
            metric_subset = subset[subset['Metric'] == metric]

            # Every replica appearing on either side of a pair (sorted by np.unique)
            replicas = np.unique(np.concatenate([metric_subset['Replica_A'].unique(),
                                                 metric_subset['Replica_B'].unique()]))

            for replica in replicas:
                # All pairs in which this replica participates
                involves_replica = (metric_subset['Replica_A'] == replica) | (metric_subset['Replica_B'] == replica)
                replica_rows = metric_subset[involves_replica]

                stability_rows.append({
                    'Time_min': time_val,
                    'Dex_conc_nM': conc_val,
                    'Metric': metric,
                    'Replica': replica,
                    'Mean_Wasserstein_distance': replica_rows['Wasserstein_distance'].mean(),
                    'Num_pairs': len(replica_rows)
                })

    return pd.DataFrame(stability_rows)

def plot_replica_stability_heatmap(stability_df, metric_filter=None):
    """
    Display a heatmap of replica stability scores (replicas x conditions).

    Args:
        stability_df (pd.DataFrame): output of compute_replica_stability.
        metric_filter (str or None): restrict to one metric label (e.g. 'nuc');
            None keeps every metric.

    Returns:
        None (the heatmap is shown with plt.show()).
    """
    if metric_filter is None:
        df_plot = stability_df.copy()
    else:
        df_plot = stability_df[stability_df['Metric'] == metric_filter].copy()

    def _condition_label(row):
        # Compact column label: T<time>_C<conc>_M<metric>
        return f"T{int(row['Time_min'])}_C{row['Dex_conc_nM']}_M{row['Metric']}"

    df_plot['Condition'] = df_plot.apply(_condition_label, axis=1)

    # Replicas as rows, conditions as columns, cell = mean distance to other replicas
    pivot = df_plot.pivot_table(index='Replica', columns='Condition', values='Mean_Wasserstein_distance')

    # Grow the figure with the number of conditions/replicas, with a minimum size
    fig_width = max(8, len(pivot.columns) * 0.6)
    fig_height = max(5, len(pivot.index) * 0.5)
    plt.figure(figsize=(fig_width, fig_height))

    sns.heatmap(pivot, annot=True, fmt=".1f", cmap='Reds', cbar_kws={'label': 'Mean Wasserstein Distance'})

    title_suffix = (' - Metric: ' + metric_filter) if metric_filter else ''
    plt.title(f"Replica Stability Heatmap{title_suffix}")
    plt.xlabel("Condition (Time_min, Conc_nM, Metric)")
    plt.ylabel("Replica")
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

In [None]:
# Metric label -> per-cell column compared between replicas
metrics_to_compare = [
    ('nuc', 'num_nuc_spots'),
    ('cyto', 'num_cyto_spots'),
    ('ts', 'num_ts'),
    ('foci', 'num_foci'),
]

# NOTE(review): `df` is not defined in any cell shown in this notebook. It must already
# exist in the kernel with 'time', 'dex_conc', 'replica' and the columns listed above.
# Confirm which table it is (possibly DUSP1_data) before running from a fresh kernel.
pairwise_results_df = compute_pairwise_comparisons(df=df, metrics_to_compare=metrics_to_compare)

# Optionally persist the pairwise table:
# pairwise_results_df.to_csv("pairwise_results.csv", index=False)

# Example: replica-by-replica distances for nuclear counts at 75 min, 100 nM
plot_wasserstein_heatmap(pairwise_results_df, time_val=75, conc_val=100.0, metric='nuc')

In [None]:
# Per-(time, conc, metric) summary of the worst-disagreeing replica pair
summary_df = generate_summary_table(pairwise_results_df, distance_threshold=40.0)

# Optionally persist the summary:
# summary_df.to_csv("summary_table.csv", index=False)

# Conditions with the largest replica disagreement. A notebook only auto-displays a
# cell's LAST expression, so this table must be printed explicitly or it is discarded.
top_conditions = summary_df.sort_values(by='Max_Wasserstein_distance', ascending=False).head(10)
print(top_conditions)

# Example ECDF plot.
# NOTE(review): `df` is not defined in the visible cells -- confirm which table it refers to.
plot_ecdf_per_replica(df, time_val=75, conc_val=100.0, metric_name='num_nuc_spots', column_name='num_nuc_spots')

In [None]:
# Flag cells lying in the extreme tails relative to the other replicas, only for
# conditions whose worst replica pair exceeds the distance threshold.
# NOTE(review): `df` is not defined in the visible cells -- confirm which table it refers to.
outlier_cells_df = identify_outlier_cells(
    df,
    pairwise_results_df,
    metrics_to_compare,
    distance_threshold=40.0,
    percentile_cutoff=5,
)

# Write to the current working directory, then preview the first rows
outlier_cells_df.to_csv("identified_outlier_cells.csv", index=False)
outlier_cells_df.head()

In [None]:
# Mean Wasserstein distance from each replica to the other replicas, per condition
# and metric. It is computed from the compute_pairwise_comparisons output, which is
# `pairwise_results_df`; `pairwise_results_extended_df` was never defined.
replica_stability_df = compute_replica_stability(pairwise_results_df)

# Least stable replica/condition/metric combinations first
replica_stability_df.sort_values(by='Mean_Wasserstein_distance', ascending=False).head(10)