# In [None]:
# %%
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
from scipy import stats
from scipy.optimize import curve_fit
import os
from typing import List, Tuple, Dict, Optional, Union
import glob
import matplotlib.patches as mpatches

# Import the phase_dominance module
import phase_dominance as pd_analysis

# Set plot style. matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8-*'; fall back to the older name so the notebook also runs on
# pre-3.6 installations (plt.style.use raises OSError for unknown styles).
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_context("notebook", font_scale=1.2)

# %% [markdown]
# # Phase Dominance Analysis for 3D Systems
# 
# This script analyzes phase dominance in 3D systems across multiple experimental
# conditions. It processes multiple input files, creates slices along a specified
# dimension, and visualizes dominance trends within and across slices.

# %%
# Configuration section - edit these values to match your dataset.

# File loading settings.
# The search pattern is os.path.join(data_directory, file_pattern), so an empty
# directory string searches the current working directory.
file_config = dict(
    data_directory='',         # Directory containing data files ('' = current directory)
    file_pattern='*.csv',      # Glob pattern selecting the input files
    experiment_labels=None,    # Optional list of labels, one per file; None -> use file names
)

# Column names in your dataset.
column_config = dict(
    # Concentration columns for the three components (all required for 3D analysis).
    # NOTE(review): names suggest total concentrations per imaging channel — confirm.
    comp_A_col='647_conc_tot_CT',               # Component A concentration column
    comp_B_col='488_conc_tot_CT',               # Component B concentration column
    comp_C_col='546_conc_tot_CT',               # Component C concentration column

    # Boolean column flagging whether a sample is phase separated.
    phase_sep_col='PS',

    # Raw intensity measurements.
    dilute_intensity_col='5-25%_488_CT',        # Raw dilute-phase intensity
    total_intensity_col='488_tot_CT',           # Raw total intensity

    # Auxiliary columns.
    condensate_label_col='Condensate label',    # Condensate type label
    frame_col='continuous_frame',               # Frame / time-point index
)

# Display labels for each component in plots (fill these in).
labels = dict(
    comp_A_label='',
    comp_B_label='',
    comp_C_label='',
)

# Concentration range per component; use None to auto-detect from the data.
limits = dict(
    comp_A_lims=[0, 100],
    comp_B_lims=[0, 10],
    comp_C_lims=[50, 200],
)

# Analysis configuration.
analysis_config = dict(
    # Slicing.
    slice_dimension='C',           # Component to slice along ('A', 'B' or 'C')
    n_slices=4,                    # Number of slices along that component

    # Analysis parameters.
    grid_resolution=30,            # Grid resolution of the phase-diagram shading
    min_droplets_per_cell=5,       # Minimum droplets required per grid cell
    n_windows=10,                  # Number of windows for the dominance calculation
    n_response_plots=3,            # Response plots shown per experiment and slice
    min_data_threshold=15,         # Minimum data points required for fitting
    combined_phase_diagram=True,   # Build one combined phase diagram for all slices
    protein_role='B',              # Component analysed for dominance ('A' or 'B')
    compare_experiments=True,      # Compare dominance across experiments
    compare_slices=True,           # Compare dominance across slices
)

# Visualization settings.
vis_config = dict(
    phase_diagram_cmap='coolwarm',       # Colormap for phase diagrams
    dominance_color_palette='viridis',   # Palette for dominance plots
    slice_color_palette='Blues_r',       # Palette for slice comparisons
    figure_dpi=150,                      # DPI of saved figures
    save_figures=True,                   # Whether figures are written to disk
    output_directory='./outputs',        # Destination directory for saved figures
)

# %%
# Helper functions for the analysis


# In [None]:

def load_experimental_data(config: Dict) -> Tuple[List[pd.DataFrame], List[str]]:
    """
    Load multiple experimental datasets from CSV files.

    Files matching ``os.path.join(config['data_directory'], config['file_pattern'])``
    are read in sorted order. Files that fail to parse are reported and skipped.
    If no file matches, three demo datasets are generated with
    ``generate_demo_data`` instead.

    Parameters:
    ----------
    config : Dict
        File loading settings with keys ``'data_directory'``, ``'file_pattern'``
        and ``'experiment_labels'`` (a list of labels, or None to use file names).

    Returns:
    -------
    Tuple[List[pd.DataFrame], List[str]]
        The loaded dataframes and their experiment labels. Default labels (file
        stem or ``Demo_Experiment_<n>``) are derived only from datasets that were
        actually loaded, so they stay aligned with the returned dataframes. A
        user-supplied label list shorter than the number of datasets is padded
        with these defaults. The list in ``config`` itself is never modified.
    """
    file_pattern = os.path.join(config['data_directory'], config['file_pattern'])
    file_list = sorted(glob.glob(file_pattern))

    if not file_list:
        print(f"No files found matching {file_pattern}. Generating demo data...")
        datasets = [generate_demo_data(n_points=5000, seed_offset=i) for i in range(3)]
        default_labels = [f"Demo_Experiment_{i+1}" for i in range(len(datasets))]
    else:
        print(f"Found {len(file_list)} files matching pattern.")
        datasets = []
        default_labels = []
        for file_path in file_list:
            try:
                df = pd.read_csv(file_path)
            except Exception as e:
                # Best-effort loading: report the bad file and continue with the rest.
                print(f"Error loading {file_path}: {str(e)}")
                continue
            datasets.append(df)
            # Record the label only for files that loaded, keeping labels aligned.
            default_labels.append(os.path.splitext(os.path.basename(file_path))[0])
            print(f"Loaded {os.path.basename(file_path)} with {len(df)} rows")

    user_labels = config['experiment_labels']
    if user_labels is None:
        exp_labels = default_labels
    else:
        # Copy so the caller's configuration list is not mutated by the padding below.
        exp_labels = list(user_labels)
        if len(exp_labels) < len(datasets):
            exp_labels.extend(default_labels[len(exp_labels):])

    return datasets, exp_labels

def generate_demo_data(n_points=5000, seed_offset=0):
    """
    Build a synthetic three-component dataset for demonstrations.

    NumPy's global random generator is reseeded with ``42 + seed_offset``, so
    each offset yields a reproducible, distinct dataset. The offset also shifts
    the toy phase-separation model.

    Parameters:
    ----------
    n_points : int
        Number of rows to generate.
    seed_offset : int
        Added to the base seed and used to perturb the model coefficients.

    Returns:
    -------
    pd.DataFrame
        Columns: ``comp_A_conc``, ``comp_B_conc``, ``comp_C_conc``, ``PS``,
        ``dilute_A_intensity``, ``total_A_intensity``, ``Condensate label``,
        ``continuous_frame`` and ``Phase separated`` (a copy of ``PS``).
    """
    rng = np.random
    rng.seed(42 + seed_offset)
    s = seed_offset

    # Total concentrations (the draw order below is fixed for reproducibility).
    total_a = rng.uniform(0, 2, n_points)
    total_b = rng.uniform(10, 40, n_points)
    total_c = rng.uniform(120, 250, n_points)

    # Logistic toy model: A and B promote demixing, C opposes it.
    term_a = (total_a - 1 + 0.2*s) * 5
    term_b = (total_b - 25 - s) * 0.2
    term_c = (total_c - 180 + 10*s) * 0.05
    ps_probability = 1 / (1 + np.exp(-(term_a + term_b - term_c)))
    is_separated = rng.random(n_points) < ps_probability

    # Fraction of A left in the dilute phase: lower (and B-dependent) when demixed.
    demixed_fraction = 0.5 + 0.3 * rng.random(n_points) - (0.01 + 0.005*s) * total_b
    mixed_fraction = 0.95 + 0.05 * rng.random(n_points)
    dilute_fraction = np.where(is_separated, demixed_fraction, mixed_fraction)

    return pd.DataFrame({
        'comp_A_conc': total_a,
        'comp_B_conc': total_b,
        'comp_C_conc': total_c,
        'PS': is_separated,
        'dilute_A_intensity': total_a * np.clip(dilute_fraction, 0.1, 0.99),
        'total_A_intensity': total_a,
        'Condensate label': 'DefaultCondensate',
        'continuous_frame': rng.randint(0, 100, n_points),
        'Phase separated': is_separated,
    })

def determine_concentration_limits(datasets: List[pd.DataFrame], 
                                  col_names: Dict, 
                                  config_limits: Dict) -> Dict:
    """
    Determine concentration limits for components A, B and C.

    A user-provided limit is kept verbatim. A limit that is None or missing is
    auto-detected as the global min/max of the component column across all
    datasets, padded by 5% of the range on each side. The lower bound is
    clamped at 0.

    Parameters:
    ----------
    datasets : List[pd.DataFrame]
        Dataframes containing the experimental data.
    col_names : Dict
        Column-name configuration with keys ``'comp_A_col'``, ``'comp_B_col'``
        and ``'comp_C_col'``. Only keys for auto-detected components are read.
    config_limits : Dict
        User-provided limits under ``'comp_A_lims'``, ``'comp_B_lims'`` and
        ``'comp_C_lims'`` (None or absent means auto-detect). Not modified.

    Returns:
    -------
    Dict
        Shallow copy of ``config_limits`` with every component limit filled in.
        Auto-detected limits are ``(low, high)`` tuples.

    Raises:
    ------
    ValueError
        If a limit must be auto-detected but no dataset holds a non-missing
        value in that column (including when ``datasets`` is empty).
    """
    limits_out = config_limits.copy()

    for comp in ('A', 'B', 'C'):
        lims_key = f'comp_{comp}_lims'
        if limits_out.get(lims_key) is not None:
            continue  # user-supplied limits are kept as given

        col = col_names[f'comp_{comp}_col']
        # Series.min/max skip NaN, so a dataset whose column is entirely missing
        # cannot poison the result (built-in min/max are order-dependent on NaN).
        min_val = pd.Series([df[col].min() for df in datasets], dtype=float).min()
        max_val = pd.Series([df[col].max() for df in datasets], dtype=float).max()
        if pd.isna(min_val) or pd.isna(max_val):
            raise ValueError(
                f"Cannot auto-detect limits for column '{col}': "
                "no non-missing values in any dataset")

        # Add a small buffer around the observed range.
        buffer = (max_val - min_val) * 0.05
        limits_out[lims_key] = (max(0, min_val - buffer), max_val + buffer)

    return limits_out

def create_slices(df: pd.DataFrame, 
                 slice_dim: str,
                 slice_col: str, 
                 slice_lims: Tuple[float, float],
                 n_slices: int) -> Tuple[List[pd.DataFrame], List[float], float]:
    """
    Split a dataset into equal-width slices along one column.

    The range ``slice_lims`` is divided into ``n_slices`` equal bins. Each bin is
    half-open ``[low, high)`` except the last, which is closed ``[low, high]``.
    Every value inside the limits, including one exactly at the upper limit,
    therefore lands in exactly one slice. Values outside the limits are dropped.

    Parameters:
    ----------
    df : pd.DataFrame
        The dataset to slice.
    slice_dim : str
        Name of the slice dimension ('A', 'B' or 'C'), used only in log messages.
    slice_col : str
        Column whose values define the slices.
    slice_lims : Tuple[float, float]
        Lower and upper limits of the sliced range.
    n_slices : int
        Number of slices to create (must be >= 1).

    Returns:
    -------
    Tuple[List[pd.DataFrame], List[float], float]
        Copies of the sliced dataframes, the slice centers, and the slice width.

    Raises:
    ------
    ValueError
        If ``n_slices`` is less than 1.
    """
    if n_slices < 1:
        raise ValueError(f"n_slices must be at least 1, got {n_slices}")

    # linspace pins the last edge exactly to slice_lims[1].
    slice_edges = np.linspace(slice_lims[0], slice_lims[1], n_slices + 1)
    slice_window_size = slice_edges[1] - slice_edges[0]
    slice_centers = [edge + slice_window_size/2 for edge in slice_edges[:-1]]

    values = df[slice_col]
    slices = []
    for i in range(n_slices):
        slice_min = slice_edges[i]
        slice_max = slice_edges[i+1]

        # Close the final bin so points at the upper limit are not lost.
        if i == n_slices - 1:
            below_max = values <= slice_max
        else:
            below_max = values < slice_max
        slice_df = df[(values >= slice_min) & below_max].copy()
        slices.append(slice_df)

        print(f"Slice {i+1}/{n_slices}: {slice_dim} = {slice_min:.2f}-{slice_max:.2f}, {len(slice_df)} points")

    return slices, slice_centers, slice_window_size

def plot_sliced_phase_diagrams(slices: List[pd.DataFrame], 
                             slice_centers: List[float],
                             slice_window_size: float,
                             x_col: str, 
                             y_col: str,
                             phase_sep_col: str,
                             x_lims: Tuple[float, float],
                             y_lims: Tuple[float, float],
                             slice_label: str,
                             x_label: str,
                             y_label: str,
                             config: Dict,
                             analysis_config: Optional[Dict] = None) -> plt.Figure:
    """
    Create a grid of phase diagrams, one panel per slice.

    Each panel is drawn with ``pd_analysis.plot_shaded_grid`` using
    ``phase_sep_col`` as the averaged value on a 0–1 color scale. Unused grid
    cells are hidden.

    Parameters:
    ----------
    slices : List[pd.DataFrame]
        Sliced dataframes (at least one).
    slice_centers : List[float]
        Center value of each slice.
    slice_window_size : float
        Full width of each slice; the panel titles show ± half of it.
    x_col, y_col : str
        Column names for the x and y axes.
    phase_sep_col : str
        Column holding the phase separation state.
    x_lims, y_lims : Tuple[float, float]
        Axis limits passed to the grid plot.
    slice_label : str
        Label for the slice dimension (used in panel titles).
    x_label, y_label : str
        Axis labels.
    config : Dict
        Visualization settings. Reads ``'phase_diagram_cmap'`` and
        ``'save_figures'``, plus ``'output_directory'`` and ``'figure_dpi'``
        when saving.
    analysis_config : Dict, optional
        Analysis settings providing ``'grid_resolution'`` and
        ``'min_droplets_per_cell'``. Defaults to the module-level
        ``analysis_config`` dictionary, preserving the previous behavior.

    Returns:
    -------
    plt.Figure
        The generated figure.

    Raises:
    ------
    ValueError
        If ``slices`` is empty.
    """
    if analysis_config is None:
        # Backward compatibility: earlier versions always read the global settings.
        analysis_config = globals()['analysis_config']

    n_slices = len(slices)
    if n_slices == 0:
        raise ValueError("plot_sliced_phase_diagrams needs at least one slice")

    # Roughly square grid of panels.
    grid_cols = int(np.ceil(np.sqrt(n_slices)))
    grid_rows = int(np.ceil(n_slices / grid_cols))

    # squeeze=False always returns a 2-D array, so one ravel handles every grid shape.
    fig, axes_grid = plt.subplots(grid_rows, grid_cols, figsize=(4*grid_cols, 4*grid_rows),
                                  sharex=True, sharey=True, squeeze=False)
    axes = axes_grid.ravel()

    cmap = plt.get_cmap(config['phase_diagram_cmap'])

    for ax, slice_df, center in zip(axes, slices, slice_centers):
        pd_analysis.plot_shaded_grid(
            df=slice_df,
            x_col=x_col,
            y_col=y_col,
            avg_col=phase_sep_col,
            x_lim=x_lims,
            y_lim=y_lims,
            N=analysis_config['grid_resolution'],
            minDroplets=analysis_config['min_droplets_per_cell'],
            cmap=cmap,
            vlims=(0, 1),
            ax=ax
        )
        ax.set_title(f'{slice_label} = {center:.2f} ± {slice_window_size/2:.2f}')
        ax.set_xlabel(f'{x_label}')
        ax.set_ylabel(f'{y_label}')

    # Hide grid cells that have no slice.
    for ax in axes[n_slices:]:
        ax.set_visible(False)

    # Operate on this figure explicitly, not whatever pyplot considers "current".
    fig.tight_layout()

    if config['save_figures']:
        os.makedirs(config['output_directory'], exist_ok=True)
        fig_path = os.path.join(config['output_directory'], 'sliced_phase_diagrams.png')
        fig.savefig(fig_path, dpi=config['figure_dpi'], bbox_inches='tight')
        print(f"Saved sliced phase diagrams to {fig_path}")

    return fig

def analyze_slices_dominance(slices: List[pd.DataFrame],
                           slice_centers: List[float],
                           slice_label: str,
                           experiment_label: str,
                           protein_col: str,
                           protein_lims: Tuple[float, float],
                           protein_label: str,
                           modulator_col: str,
                           modulator_lims: Tuple[float, float], 
                           modulator_label: str,
                           dilute_intensity_col: str,
                           total_intensity_col: str,
                           phase_sep_col: str,
                           condensate_label_col: str,
                           analysis_config: Dict,
                           vis_config: Dict,
                           slice_window_size: Optional[float] = None) -> Dict:
    """
    Analyze dominance in every slice of a single experiment.

    Slices with fewer than ``analysis_config['min_data_threshold']`` rows are
    skipped. The others are passed to ``pd_analysis.dominance_sweep``. When at
    least two slices produce results and ``analysis_config['compare_slices']``
    is true, a comparison plot is drawn and optionally saved. The input slice
    DataFrames are not modified.

    Parameters:
    ----------
    slices : List[pd.DataFrame]
        Sliced dataframes.
    slice_centers : List[float]
        Center value of each slice.
    slice_label : str
        Label for the slice dimension.
    experiment_label : str
        Label for the experiment (also used in saved file names).
    protein_col, protein_lims, protein_label
        Column, limits and label of the analyzed (protein) component.
    modulator_col, modulator_lims, modulator_label
        Column, limits and label of the modulating component.
    dilute_intensity_col, total_intensity_col : str
        Intensity columns forwarded to the dominance sweep.
    phase_sep_col : str
        Phase separation state column.
    condensate_label_col : str
        Condensate label column. If absent, the copy passed to the sweep gets
        ``'DefaultCondensate'``.
    analysis_config : Dict
        Analysis settings (``min_data_threshold``, ``n_windows``,
        ``n_response_plots``, ``compare_slices``).
    vis_config : Dict
        Visualization settings (``slice_color_palette``, ``save_figures``,
        ``output_directory``, ``figure_dpi``).
    slice_window_size : float, optional
        Slice width forwarded to the comparison plot. When None it is taken as
        the spacing of the first two slice centers, which equals the width for
        evenly spaced slices. It is 0 when fewer than two centers are given.

    Returns:
    -------
    Dict
        Keys ``experiment``, ``slice_responses``, ``slice_response_errors``,
        ``slice_modulator_values``, ``slice_centers`` (analyzed slices only) and
        ``slice_results`` (one dict per analyzed slice).
    """
    all_slice_responses = []
    all_slice_response_errors = []
    all_slice_modulator_values = []
    slice_results = []

    for i, (slice_df, center) in enumerate(zip(slices, slice_centers)):
        print(f"  Analyzing dominance for slice {i+1}/{len(slices)}: {slice_label} = {center:.2f}")

        if len(slice_df) < analysis_config['min_data_threshold']:
            print(f"  Skipping slice {i+1} - insufficient data points ({len(slice_df)})")
            continue

        # Ensure the condensate label column exists on a copy; the caller's
        # DataFrame is left untouched.
        if condensate_label_col not in slice_df.columns:
            slice_df = slice_df.assign(**{condensate_label_col: 'DefaultCondensate'})

        ps_responses, nps_responses, ps_mod_values, nps_mod_values, ps_errors, nps_errors = pd_analysis.dominance_sweep(
            df=slice_df,
            protein_conc_col=protein_col,
            protein_min=protein_lims[0],
            protein_max=protein_lims[1],
            modulator_conc_col=modulator_col,
            modulator_min=modulator_lims[0],
            modulator_max=modulator_lims[1],
            condensate_label=slice_df[condensate_label_col].iloc[0],
            numberWindows=analysis_config['n_windows'],
            numberResponsePlots=analysis_config['n_response_plots'],
            minimumDataThreshold=analysis_config['min_data_threshold'],
            modulator_label=modulator_label,
            protein_label=protein_label,
            phase_sep_col=phase_sep_col,
            dilute_intensity_col=dilute_intensity_col,
            total_intensity_col=total_intensity_col
        )

        all_slice_responses.append(ps_responses)
        all_slice_response_errors.append(ps_errors)
        all_slice_modulator_values.append(ps_mod_values)
        slice_results.append({
            'experiment': experiment_label,
            'slice': i,
            'slice_center': center,
            'phase_sep_responses': ps_responses,
            'phase_sep_errors': ps_errors,
            'non_phase_sep_responses': nps_responses,
            'non_phase_sep_errors': nps_errors,
            'phase_sep_modulator_values': ps_mod_values,
            'non_phase_sep_modulator_values': nps_mod_values
        })

    if len(all_slice_responses) > 1 and analysis_config['compare_slices']:
        if slice_window_size is None:
            # Evenly spaced slices: center spacing equals the slice width.
            slice_window_size = (abs(slice_centers[1] - slice_centers[0])
                                 if len(slice_centers) > 1 else 0)

        # Color each curve by its original slice index so a slice keeps the
        # same color even when neighboring slices were skipped.
        palette = sns.color_palette(vis_config['slice_color_palette'], len(slices))
        slice_colors = [palette[result['slice']] for result in slice_results]

        slice_fig = pd_analysis.plot_dominance_comparison(
            response_list=all_slice_responses,
            modulator_values_list=all_slice_modulator_values,
            response_errors_list=all_slice_response_errors,
            slice_labels=[slice_label],
            slice_values=[result['slice_center'] for result in slice_results],
            slice_window_size=slice_window_size,
            modulator_label=modulator_label,
            response_component_label=protein_label,
            figsize=(10, 6),
            custom_colors=slice_colors,
            confidence_level=0.95,
            show_errors=True
        )

        if vis_config['save_figures']:
            os.makedirs(vis_config['output_directory'], exist_ok=True)
            fig_path = os.path.join(vis_config['output_directory'], 
                                    f'{experiment_label}_slices_comparison.png')
            slice_fig.savefig(fig_path, dpi=vis_config['figure_dpi'], bbox_inches='tight')
            print(f"Saved slice comparison for {experiment_label} to {fig_path}")

        plt.show()

    return {
        'experiment': experiment_label,
        'slice_responses': all_slice_responses,
        'slice_response_errors': all_slice_response_errors,
        'slice_modulator_values': all_slice_modulator_values,
        'slice_centers': [result['slice_center'] for result in slice_results],
        'slice_results': slice_results
    }

def plot_response_difference(
    results: List[Dict],
    modulator_label: str = "Modulator",
    protein_label: str = "Component",
    figsize: Tuple[int, int] = (10, 6),
    custom_colors: Optional[List] = None,
    show_errors: bool = True,
    confidence_level: float = 0.95,
    legend_loc: str = 'best'
) -> plt.Figure:
    """
    Plot the difference between non-phase-separated and phase-separated responses.

    For each result, modulator values present in both the phase-separated (PS)
    and non-phase-separated (NPS) series are matched, and ``NPS - PS`` is
    plotted against them. When both error series are non-empty and
    ``show_errors`` is true, a band of ``± z * sqrt(err_ps**2 + err_nps**2)`` is
    drawn. ``z`` is the two-sided normal quantile for ``confidence_level``.

    Parameters:
    ----------
    results : List[Dict]
        Result dictionaries with keys ``'experiment'``,
        ``'phase_sep_responses'``, ``'non_phase_sep_responses'``,
        ``'phase_sep_modulator_values'``, ``'non_phase_sep_modulator_values'``
        and optionally ``'phase_sep_errors'`` / ``'non_phase_sep_errors'``.
        The entries of ``slice_results`` from ``analyze_slices_dominance`` have
        this shape. Sequences may be lists or NumPy arrays.
    modulator_label : str, optional
        Label for the modulator axis. Defaults to "Modulator".
    protein_label : str, optional
        Label for the protein component, shown in the y-axis label.
        Defaults to "Component".
    figsize : Tuple[int, int], optional
        Figure size. Defaults to (10, 6).
    custom_colors : Optional[List], optional
        One color per result. If None, a 'viridis' palette is used.
    show_errors : bool, optional
        Whether to show error bands. Defaults to True.
    confidence_level : float, optional
        Confidence level for error bands. Defaults to 0.95.
    legend_loc : str, optional
        Location of the legend. Defaults to 'best'.

    Returns:
    -------
    plt.Figure
        The generated matplotlib figure.
    """
    colors = sns.color_palette('viridis', len(results)) if custom_colors is None else custom_colors

    fig, ax = plt.subplots(figsize=figsize)

    # Two-sided normal quantile, e.g. ~1.96 for a 95% interval.
    z_score = stats.norm.ppf((1 + confidence_level) / 2)

    experiment_labels = [result['experiment'] for result in results]

    plotted_any = False
    for i, result in enumerate(results):
        ps_responses = result.get('phase_sep_responses', [])
        nps_responses = result.get('non_phase_sep_responses', [])
        ps_mod_values = result.get('phase_sep_modulator_values', [])
        nps_mod_values = result.get('non_phase_sep_modulator_values', [])
        ps_errors = result.get('phase_sep_errors', [])
        nps_errors = result.get('non_phase_sep_errors', [])

        # Map each NPS modulator value to its first index (same as list.index),
        # giving O(n) matching that also works when the values are NumPy arrays.
        nps_lookup = {}
        for nps_idx, mod_val in enumerate(nps_mod_values):
            nps_lookup.setdefault(mod_val, nps_idx)

        index_pairs = [(ps_idx, nps_lookup[mod_val])
                       for ps_idx, mod_val in enumerate(ps_mod_values)
                       if mod_val in nps_lookup]

        if not index_pairs:
            print(f"No common modulator values found for {result['experiment']}")
            continue

        common_mod_values = [ps_mod_values[ps_idx] for ps_idx, _ in index_pairs]
        differences = [nps_responses[nps_idx] - ps_responses[ps_idx]
                       for ps_idx, nps_idx in index_pairs]

        ax.plot(
            common_mod_values,
            differences,
            '-o',
            color=colors[i],
            label=experiment_labels[i],
            linewidth=2,
            markersize=5
        )
        plotted_any = True

        # Use explicit None/len checks: truthiness of a NumPy array is ambiguous.
        has_errors = (show_errors
                      and ps_errors is not None and nps_errors is not None
                      and len(ps_errors) > 0 and len(nps_errors) > 0)
        if has_errors:
            half_widths = [z_score * np.sqrt(ps_errors[ps_idx]**2 + nps_errors[nps_idx]**2)
                           for ps_idx, nps_idx in index_pairs]
            ax.fill_between(
                common_mod_values,
                [diff - hw for diff, hw in zip(differences, half_widths)],
                [diff + hw for diff, hw in zip(differences, half_widths)],
                color=colors[i],
                alpha=0.2
            )

    # Zero line for reference.
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)

    ax.set_xlabel(f'{modulator_label} Concentration')
    ax.set_ylabel(f'{protein_label} Response Difference (Mixed - Demixed)')
    ax.grid(True, alpha=0.3)
    if plotted_any:
        # Avoid matplotlib's "no artists with labels" warning on an empty plot.
        ax.legend(loc=legend_loc)

    title = "Difference Between Mixed and Demixed Responses"
    if show_errors:
        title += f" with {int(confidence_level*100)}% Confidence Intervals"
    ax.set_title(title)

    fig.tight_layout()
    return fig

def compare_experiments_across_slices(experiment_results: List[Dict],
                                     experiment_labels: List[str],
                                     slice_label: str,
                                     modulator_label: str,
                                     protein_label: str,
                                     vis_config: Dict) -> plt.Figure:
    """
    Create a grid of plots comparing experiments across slices.

    One subplot is drawn per unique slice center found across all experiments.
    In each subplot, every experiment contributes the dominance curve
    (``1 - response``) of its slice whose center is closest to the subplot's
    center, provided that slice lies within 10% of the overall span of slice
    centers (inclusive, so a zero span still matches identical centers).

    Parameters:
    ----------
    experiment_results : List[Dict]
        List of experiment result dictionaries. Each is read for the keys
        ``'slice_centers'``, ``'slice_responses'`` and
        ``'slice_modulator_values'``; the latter two are indexed in parallel
        with ``'slice_centers'``.
    experiment_labels : List[str]
        Labels for each experiment (parallel to ``experiment_results``)
    slice_label : str
        Label for the slice dimension
    modulator_label : str
        Label for the modulator
    protein_label : str
        Label for the protein
    vis_config : Dict
        Visualization configuration; reads ``'dominance_color_palette'``,
        ``'save_figures'``, ``'output_directory'`` and ``'figure_dpi'``.

    Returns:
    -------
    plt.Figure
        The generated figure (containing an explanatory note and no visible
        axes when there are no slice centers at all)
    """
    # Every slice center occurring in any experiment, deduplicated and ascending
    unique_slice_centers = sorted({center
                                   for result in experiment_results
                                   for center in result['slice_centers']})
    n_slices = len(unique_slice_centers)

    # Roughly square grid. At least one cell is allocated so an empty input
    # yields an empty figure instead of a ZeroDivisionError.
    n_cells = max(n_slices, 1)
    grid_cols = int(np.ceil(np.sqrt(n_cells)))
    grid_rows = int(np.ceil(n_cells / grid_cols))

    # squeeze=False always returns a 2D array of Axes, so ravel() works for
    # every grid shape (including 1x1) without special-casing.
    fig, axes = plt.subplots(grid_rows, grid_cols,
                             figsize=(5 * grid_cols, 4 * grid_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    # Color palette for experiments
    exp_colors = sns.color_palette(vis_config['dominance_color_palette'], len(experiment_results))

    # Matching tolerance: 10% of the span of all slice centers. The comparison
    # below is inclusive so that a zero span (all experiments sharing a single
    # slice center) still matches exact centers instead of plotting nothing.
    span = (unique_slice_centers[-1] - unique_slice_centers[0]) if n_slices else 0.0
    tolerance = 0.1 * span

    legend_added = False
    for ax, center in zip(axes, unique_slice_centers):
        for j, result in enumerate(experiment_results):
            exp_centers = result['slice_centers']
            if len(exp_centers) == 0:
                continue

            # This experiment's slice closest to the current grid center
            closest_idx = int(np.argmin(np.abs(np.asarray(exp_centers) - center)))
            if abs(exp_centers[closest_idx] - center) > tolerance:
                continue

            responses = result['slice_responses'][closest_idx]
            mod_values = result['slice_modulator_values'][closest_idx]

            # Dominance is plotted as the complement of the response
            ax.plot(
                mod_values,
                [1 - r for r in responses],
                '-o',
                color=exp_colors[j],
                label=experiment_labels[j],
                linewidth=2,
                markersize=4
            )

        ax.set_title(f'{slice_label} = {center:.2f}')
        ax.set_xlabel(f'{modulator_label}')
        ax.set_ylabel(f'{protein_label} Dominance')
        ax.set_ylim(-0.1, 1.1)
        ax.grid(True, alpha=0.3)

        # A single legend for the grid, on the first subplot that actually has
        # labelled curves (avoids matplotlib's "no artists" warning).
        if not legend_added and ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best')
            legend_added = True

    # Hide any unused subplots
    for ax in axes[n_slices:]:
        ax.set_visible(False)
    if n_slices == 0:
        fig.text(0.5, 0.5, 'No slice data to compare', ha='center', va='center')

    fig.tight_layout()

    # Save figure if configured
    if vis_config['save_figures']:
        os.makedirs(vis_config['output_directory'], exist_ok=True)
        fig_path = os.path.join(vis_config['output_directory'], 'experiment_comparison_across_slices.png')
        # Save this figure explicitly rather than pyplot's implicit "current" figure
        fig.savefig(fig_path, dpi=vis_config['figure_dpi'], bbox_inches='tight')
        print(f"Saved experiment comparison across slices to {fig_path}")

    return fig

# %%
# Main execution block

# Create output directory if needed
if vis_config['save_figures']:
    os.makedirs(vis_config['output_directory'], exist_ok=True)

# Load experimental data
print("Loading experimental data...")
datasets, experiment_labels = load_experimental_data(file_config)
print(f"Loaded {len(datasets)} datasets with labels: {experiment_labels}")

# Determine concentration limits
limits = determine_concentration_limits(datasets, column_config, limits)
print(f"Concentration limits - Component A: {limits['comp_A_lims']}, Component B: {limits['comp_B_lims']}, Component C: {limits['comp_C_lims']}")

# Extract column names and labels
comp_A_col = column_config['comp_A_col']
comp_B_col = column_config['comp_B_col']
comp_C_col = column_config['comp_C_col']
phase_sep_col = column_config['phase_sep_col']
dilute_intensity_col = column_config['dilute_intensity_col']
total_intensity_col = column_config['total_intensity_col']
condensate_label_col = column_config['condensate_label_col']

comp_A_label = labels['comp_A_label']
comp_B_label = labels['comp_B_label']
comp_C_label = labels['comp_C_label']

comp_A_lims = limits['comp_A_lims']
comp_B_lims = limits['comp_B_lims']
comp_C_lims = limits['comp_C_lims']

# (column, limits, label) for each component, keyed by the letter used in
# analysis_config['slice_dimension'] and analysis_config['protein_role'].
component_info = {
    'A': (comp_A_col, comp_A_lims, comp_A_label),
    'B': (comp_B_col, comp_B_lims, comp_B_label),
    'C': (comp_C_col, comp_C_lims, comp_C_label),
}

# Determine slice dimension and parameters.
# Each slice dimension maps to the (sliced, x-axis, y-axis) components used
# for the 2D phase diagrams within a slice.
slice_layouts = {
    'A': ('A', 'B', 'C'),
    'B': ('B', 'A', 'C'),
    'C': ('C', 'B', 'A'),
}
slice_dim = analysis_config['slice_dimension']
if slice_dim not in ('A', 'B', 'C'):
    # Fall back to slicing along C, but make the misconfiguration visible and
    # propagate the dimension actually used (slice_dim is passed to create_slices).
    print(f"Warning: unrecognized slice_dimension {slice_dim!r}; defaulting to 'C'.")
    slice_dim = 'C'
slice_key, x_key, y_key = slice_layouts[slice_dim]
slice_col, slice_lims, slice_label = component_info[slice_key]
x_col, x_lims, x_label = component_info[x_key]
y_col, y_lims, y_label = component_info[y_key]

# Determine protein and modulator for dominance analysis: one of components
# A/B is treated as the protein and the other as the modulator.
protein_role = analysis_config['protein_role']
if protein_role not in ('A', 'B'):
    print(f"Warning: unrecognized protein_role {protein_role!r}; defaulting to 'B'.")
    protein_role = 'B'
modulator_role = 'B' if protein_role == 'A' else 'A'
protein_col, protein_lims, protein_label = component_info[protein_role]
modulator_col, modulator_lims, modulator_label = component_info[modulator_role]

# Initialize list to store all experiment results
all_experiment_results = []

# Process each experiment: slice it, optionally plot the slice phase diagrams
# (first experiment only), analyze dominance per slice, and plot the
# mixed-vs-demixed response difference for every other slice.
for i, (dataset, exp_label) in enumerate(zip(datasets, experiment_labels)):
    print(f"\nProcessing experiment {i+1}/{len(datasets)}: {exp_label}")
    
    # Create slices for this experiment
    print(f"Creating {analysis_config['n_slices']} slices along dimension {slice_dim}...")
    exp_slices, slice_centers, slice_window_size = create_slices(
        df=dataset,
        slice_dim=slice_dim,
        slice_col=slice_col,
        slice_lims=slice_lims,
        n_slices=analysis_config['n_slices']
    )
    
    # Plot phase diagrams for slices if this is the first experiment
    if i == 0 and analysis_config['combined_phase_diagram']:
        print("Creating combined phase diagram for slices...")
        phase_fig = plot_sliced_phase_diagrams(
            slices=exp_slices,
            slice_centers=slice_centers,
            slice_window_size=slice_window_size,
            x_col=x_col,
            y_col=y_col,
            phase_sep_col=phase_sep_col,
            x_lims=x_lims,
            y_lims=y_lims,
            slice_label=slice_label,
            x_label=x_label,
            y_label=y_label,
            config=vis_config
        )
        plt.show()
    
    # Analyze dominance across slices
    print("Analyzing dominance across slices...")
    exp_results = analyze_slices_dominance(
        slices=exp_slices,
        slice_centers=slice_centers,
        slice_label=slice_label,
        experiment_label=exp_label,
        protein_col=protein_col,
        protein_lims=protein_lims,
        protein_label=protein_label,
        modulator_col=modulator_col,
        modulator_lims=modulator_lims,
        modulator_label=modulator_label,
        dilute_intensity_col=dilute_intensity_col,
        total_intensity_col=total_intensity_col,
        phase_sep_col=phase_sep_col,
        condensate_label_col=condensate_label_col,
        analysis_config=analysis_config,
        vis_config=vis_config
    )
    
    # Store results
    all_experiment_results.append(exp_results)
    
    # Plot response differences for this experiment's slices
    for j, slice_result in enumerate(exp_results['slice_results']):
        if j % 2 == 0:  # Only plot every other slice to avoid too many plots
            print(f"Plotting response difference for slice {j+1}...")
            diff_fig = plot_response_difference(
                results=[slice_result],
                modulator_label=modulator_label,
                protein_label=protein_label,
                figsize=(8, 5),
                show_errors=True,
                confidence_level=0.95
            )
            
            # Save figure if configured (output directory is created at startup)
            if vis_config['save_figures']:
                diff_fig_path = os.path.join(vis_config['output_directory'], 
                                           f'{exp_label}_slice{j+1}_response_diff.png')
                diff_fig.savefig(diff_fig_path, dpi=vis_config['figure_dpi'], bbox_inches='tight')
                
            plt.show()
            # Release the figure once displayed: one is created per plotted slice
            # per experiment, and when show() does not close figures (e.g. with
            # non-interactive backends) they would otherwise pile up in pyplot.
            plt.close(diff_fig)

# Cross-experiment comparison: only meaningful with at least two processed
# experiments, and only when enabled in the analysis configuration.
if len(all_experiment_results) >= 2 and analysis_config['compare_experiments']:
    print("\nComparing experiments across slices...")
    comparison_fig = compare_experiments_across_slices(
        all_experiment_results,
        experiment_labels,
        slice_label,
        modulator_label,
        protein_label,
        vis_config,
    )
    # compare_experiments_across_slices handles saving itself when configured;
    # here the figure is only displayed.
    plt.show()

# Report per-slice dominance statistics (dominance = 1 - response) for every
# experiment; slices with no phase-separated responses are left out.
print("\nDominance Analysis Summary:")
print("-" * 50)
for exp_results in all_experiment_results:
    exp_label = exp_results['experiment']
    print(f"Experiment: {exp_label}")

    for i, slice_result in enumerate(exp_results['slice_results']):
        slice_center = slice_result['slice_center']

        # Guard clause: nothing to summarize for an empty slice
        if not slice_result['phase_sep_responses']:
            continue

        dominance_values = [1 - r for r in slice_result['phase_sep_responses']]
        avg_dominance, max_dominance, min_dominance = (
            np.mean(dominance_values),
            np.max(dominance_values),
            np.min(dominance_values),
        )

        print(f"  Slice {i+1} ({slice_label} = {slice_center:.2f}):")
        print(f"    Average Dominance: {avg_dominance:.3f}")
        print(f"    Maximum Dominance: {max_dominance:.3f}")
        print(f"    Minimum Dominance: {min_dominance:.3f}")

    print("-" * 50)

print("\nAnalysis complete!")