# In [None]:
# %%
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import os
from typing import List, Tuple, Dict, Optional, Union
import glob

# Import the phase_dominance module
import phase_dominance as pd_analysis

# Set plot style (applies globally to every figure created below).
# NOTE(review): the versioned 'seaborn-v0_8-*' style names presumably require
# matplotlib >= 3.6 — confirm against the environment's pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
# Use seaborn's "notebook" context with fonts scaled by a factor of 1.2.
sns.set_context("notebook", font_scale=1.2)

# %% [markdown]
# # Phase Dominance Analysis for 2D Systems
# 
# This script analyzes phase dominance in 2D systems (two primary components) across
# multiple experimental conditions. It processes multiple input files and creates
# comparative visualizations of dominance trends.

# %%
# Configuration section - customize these parameters for your dataset.
# These dictionaries are module-level globals read by the helper functions and
# by the main execution block further down in this file.

# File loading settings (consumed by load_experimental_data)
file_config = {
    'data_directory': '',                  # Directory containing data files; joined with 'file_pattern' via os.path.join, so '' means the current working directory
    'file_pattern': '*.csv',                     # glob pattern used to find input files (e.g., '*.csv'); matches are processed in sorted order
    'experiment_labels': None,                   # Optional list of labels, one per experiment. If None, file names (without extension) are used
}

# Column names in your dataset
column_config = {
    # Main component columns (required for 2D analysis).
    # Component A is drawn on the y axis and component B on the x axis of the phase diagram.
    # NOTE(review): the example names previously given here ('546_conc_tot_CT' for A,
    # '647_conc_tot_CT' for B) did not match the configured values — confirm the channels.
    'comp_A_col': '647_conc_tot_CT',                # Component A concentration column (original example: Lysate)
    'comp_B_col': '488_conc_tot_CT',                # Component B concentration column (original example: G3BP1)
    
    # Phase separation state column
    'phase_sep_col': 'PS',                      # Column indicating phase separation state (boolean); averaged per grid cell for the phase diagram
    
    # Intensity measurement columns (forwarded unchanged to pd_analysis.dominance_sweep)
    'dilute_intensity_col': '5-25%_488_CT', # Raw dilute phase intensity
    'total_intensity_col': '488_tot_CT',   # Raw total intensity
    
    # Additional columns
    'condensate_label_col': 'Condensate label',  # Column for different condensate types; filled with 'DefaultCondensate' when missing
    'frame_col': 'continuous_frame'              # Column for frame or time point information (not read by any code visible in this file)
}

# Component labels for plots. Each label is interpolated as '<label> Concentration'
# on the axes, so an empty string produces an axis title of just ' Concentration'.
labels = {
    'comp_A_label': '',      # e.g., 'Lysate' in your original code
    'comp_B_label': '',      # e.g., 'G3BP1' in your original code
}

# Concentration limits for each component - set an entry to None to auto-detect it
# from the data (data range padded by 5%, with the lower bound clipped at 0).
limits = {
    'comp_A_lims': [30,80],                # Component A concentration range as [min, max]
    'comp_B_lims': [0,10],                # Component B concentration range as [min, max]
}

# Analysis configuration
analysis_config = {
    # Analysis parameters
    'grid_resolution': 30,              # Grid resolution of the phase diagram (passed as N to plot_shaded_grid)
    'min_droplets_per_cell': 5,         # Minimum number of droplets per grid cell (passed as minDroplets)
    'n_windows': 10,                    # Number of windows for dominance calculation (passed as numberWindows)
    'n_response_plots': 5,              # Number of response plots to show per experiment (passed as numberResponsePlots)
    'min_data_threshold': 15,           # Minimum data points required for fitting (passed as minimumDataThreshold)
    'combined_phase_diagram': True,     # Whether to create a combined phase diagram for all experiments
    'protein_role': 'B',                # Component analyzed for dominance: 'A' selects A; any other value is treated as 'B'. The other component is the modulator
}

# Visualization settings
vis_config = {
    'phase_diagram_cmap': 'coolwarm',   # Matplotlib colormap name for phase diagrams
    'dominance_color_palette': 'viridis', # Seaborn palette for dominance plots; 'default' also resolves to 'viridis'
    'figure_dpi': 150,                  # DPI for saved figures
    'save_figures': True,               # Whether to save figures
    'output_directory': './outputs',    # Directory for saved figures (created if it does not exist)
}


# In [None]:

# %%
# Helper functions for the analysis

def load_experimental_data(config: Dict) -> Tuple[List[pd.DataFrame], List[str]]:
    """
    Load multiple experimental datasets from files.
    
    Files matching ``config['file_pattern']`` inside ``config['data_directory']``
    are read with ``pd.read_csv`` in sorted order. Files that cannot be read are
    reported and skipped together with their label, so the returned lists are
    always aligned one-to-one. If no file matches, three demo datasets are
    generated instead, each with a different ``seed_offset`` so they are not
    identical copies of each other.
    
    Parameters:
    ----------
    config : Dict
        Configuration dictionary with file loading settings. Reads the keys
        'data_directory', 'file_pattern' and (optionally) 'experiment_labels'.
        The label list is never modified.
    
    Returns:
    -------
    Tuple[List[pd.DataFrame], List[str]]
        List of dataframes and corresponding experiment labels (same length).
        The i-th user-provided label is applied to the i-th matched file;
        positions without a provided label fall back to the file name without
        extension (or 'Demo_Experiment_N' for demo data).
    """
    n_demo_datasets = 3

    # Get list of files matching pattern
    file_pattern = os.path.join(config['data_directory'], config['file_pattern'])
    file_list = sorted(glob.glob(file_pattern))

    # Work on a copy so the caller's configuration list is left untouched.
    provided_labels = list(config.get('experiment_labels') or [])

    def label_for(position: int, fallback: str) -> str:
        """Return the user label for this position, or the fallback if none was given."""
        if position < len(provided_labels):
            return provided_labels[position]
        return fallback

    datasets = []
    experiment_labels = []

    if not file_list:
        # If no files found, generate demo data
        print(f"No files found matching {file_pattern}. Generating demo data...")
        for i in range(n_demo_datasets):
            # A distinct seed_offset per dataset gives each demo experiment
            # its own random draw and phase-boundary parameters.
            datasets.append(generate_demo_data(n_points=5000, seed_offset=i))
            experiment_labels.append(label_for(i, f"Demo_Experiment_{i+1}"))
        return datasets, experiment_labels

    # Load actual data files
    print(f"Found {len(file_list)} files matching pattern.")
    for i, file_path in enumerate(file_list):
        try:
            df = pd.read_csv(file_path)
        except Exception as e:
            # Best effort: report the unreadable file and move on. Its label is
            # dropped as well so labels never drift out of step with datasets.
            print(f"Error loading {file_path}: {str(e)}")
            continue

        datasets.append(df)
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        experiment_labels.append(label_for(i, base_name))
        print(f"Loaded {os.path.basename(file_path)} with {len(df)} rows")

    return datasets, experiment_labels

def generate_demo_data(n_points=5000, seed_offset=0):
    """
    Build a synthetic two-component dataset for demonstration runs.

    The global NumPy random state is re-seeded with ``42 + seed_offset``, so the
    output is reproducible for a given pair of arguments. ``seed_offset`` also
    shifts the parameters of the logistic phase-separation model.

    Returns a DataFrame with the columns 'comp_A_conc', 'comp_B_conc', 'PS',
    'dilute_A_conc', 'Normalised dilute intensity', 'Condensate label',
    'continuous_frame' and 'Phase separated' (a duplicate of 'PS').
    """
    np.random.seed(42 + seed_offset)

    # Total concentrations, drawn uniformly
    conc_a = np.random.uniform(0, 2, n_points)
    conc_b = np.random.uniform(10, 40, n_points)

    # Logistic model: probability of demixing rises with both concentrations
    logit = ((conc_a - 1 + 0.2*seed_offset) * 5 +
             (conc_b - 25 - seed_offset) * 0.2)
    demix_probability = 1 / (1 + np.exp(-logit))
    is_demixed = np.random.random(n_points) < demix_probability

    # Fraction of component A left in the dilute phase. Both branches are drawn
    # for every point (demixed first, then mixed) before selecting per point.
    demixed_fraction = 0.5 + 0.3 * np.random.random(n_points) - (0.01 + 0.005*seed_offset) * conc_b
    mixed_fraction = 0.95 + 0.05 * np.random.random(n_points)
    dilute_fraction = np.where(is_demixed, demixed_fraction, mixed_fraction)

    dilute_a = conc_a * np.clip(dilute_fraction, 0.1, 0.99)
    dilute_a_normalised = dilute_a / conc_a

    frames = np.random.randint(0, 100, n_points)

    demo = pd.DataFrame({
        'comp_A_conc': conc_a,
        'comp_B_conc': conc_b,
        'PS': is_demixed,
        'dilute_A_conc': dilute_a,
        'Normalised dilute intensity': dilute_a_normalised,
        'Condensate label': 'DefaultCondensate',
        'continuous_frame': frames
    })

    # Mirror 'PS' under the alternative column name 'Phase separated'
    demo['Phase separated'] = demo['PS']

    return demo

def determine_concentration_limits(datasets: List[pd.DataFrame], 
                                  col_names: Dict, 
                                  config_limits: Dict) -> Dict:
    """
    Resolve the concentration limits of both components.
    
    User-provided limits are kept as they are. Any entry that is None is
    replaced by the pooled data range of the matching column across all
    datasets, widened by 5% of that range on each side, with the lower
    bound never going below 0.
    
    Parameters:
    ----------
    datasets : List[pd.DataFrame]
        List of dataframes containing the experimental data
    col_names : Dict
        Dictionary of column names (reads 'comp_A_col' and 'comp_B_col')
    config_limits : Dict
        User-provided limits under 'comp_A_lims' and 'comp_B_lims'
        (either may be None). This dictionary is not modified.
    
    Returns:
    -------
    Dict
        Shallow copy of ``config_limits`` with every None entry filled in
        as a ``(lower, upper)`` tuple
    """
    resolved = config_limits.copy()

    # Pair each limits key with the data column it is derived from
    key_to_column = (
        ('comp_A_lims', col_names['comp_A_col']),
        ('comp_B_lims', col_names['comp_B_col']),
    )

    for lims_key, column in key_to_column:
        if resolved[lims_key] is not None:
            continue  # explicit user limits win

        lowest = min(df[column].min() for df in datasets)
        highest = max(df[column].max() for df in datasets)

        # Add a small buffer around the observed range
        margin = (highest - lowest) * 0.05
        resolved[lims_key] = (max(0, lowest - margin), highest + margin)

    return resolved

def plot_combined_phase_diagram(datasets: List[pd.DataFrame], 
                              experiment_labels: List[str],
                              col_names: Dict, 
                              comp_labels: Dict,
                              limits: Dict,
                              config: Dict,
                              grid_config: Optional[Dict] = None) -> plt.Figure:
    """
    Create a combined phase diagram visualization for all experiments.
    
    One panel is drawn per dataset, side by side with shared axes. Each panel
    is produced by ``pd_analysis.plot_shaded_grid`` with component B on the x
    axis, component A on the y axis, and the phase separation column averaged
    per grid cell on a fixed 0-1 color scale.
    
    Parameters:
    ----------
    datasets : List[pd.DataFrame]
        List of dataframes containing the experimental data (must not be empty)
    experiment_labels : List[str]
        Labels for each experiment, used as panel titles
    col_names : Dict
        Dictionary of column names
    comp_labels : Dict
        Dictionary of component labels for plots
    limits : Dict
        Dictionary with concentration limits
    config : Dict
        Visualization configuration. Reads 'phase_diagram_cmap' and, for
        saving, 'save_figures' (default False), 'output_directory'
        (default './outputs') and 'figure_dpi' (default 150).
    grid_config : Optional[Dict], optional
        Dictionary providing 'grid_resolution' and 'min_droplets_per_cell'.
        If None, the module-level ``analysis_config`` is used.
    
    Returns:
    -------
    plt.Figure
        The generated figure
    
    Raises:
    ------
    ValueError
        If ``datasets`` is empty.
    """
    if len(datasets) == 0:
        raise ValueError("plot_combined_phase_diagram requires at least one dataset")

    if grid_config is None:
        # Backward-compatible fallback to the module-level analysis settings
        grid_config = analysis_config

    # Extract relevant configuration
    comp_A_col = col_names['comp_A_col']
    comp_B_col = col_names['comp_B_col']
    phase_sep_col = col_names['phase_sep_col']
    
    comp_A_label = comp_labels['comp_A_label']
    comp_B_label = comp_labels['comp_B_label']
    
    comp_A_lims = limits['comp_A_lims']
    comp_B_lims = limits['comp_B_lims']
    
    cmap = plt.get_cmap(config['phase_diagram_cmap'])
    
    # squeeze=False always yields a 2D array of axes, so a single dataset
    # needs no special handling.
    fig, axes_grid = plt.subplots(1, len(datasets), figsize=(5*len(datasets), 5), 
                                 sharex=True, sharey=True, squeeze=False)
    axes = axes_grid[0]
    
    # Plot each dataset
    for df, label, ax in zip(datasets, experiment_labels, axes):
        # Plot shaded grid
        ax = pd_analysis.plot_shaded_grid(
            df=df,
            x_col=comp_B_col,
            y_col=comp_A_col,
            avg_col=phase_sep_col,
            x_lim=comp_B_lims,
            y_lim=comp_A_lims,
            N=grid_config['grid_resolution'],
            minDroplets=grid_config['min_droplets_per_cell'],
            cmap=cmap,
            vlims=(0, 1),
            ax=ax
        )
        
        ax.set_title(label)
        ax.set_xlabel(f'{comp_B_label} Concentration')
        ax.set_ylabel(f'{comp_A_label} Concentration')
    
    fig.tight_layout()
    
    # Save figure if configured. The settings come from the `config` argument
    # and the figure is saved explicitly, so the result does not depend on
    # module globals or on which figure pyplot considers current.
    if config.get('save_figures', False):
        output_directory = config.get('output_directory', './outputs')
        os.makedirs(output_directory, exist_ok=True)
        fig_path = os.path.join(output_directory, 'combined_phase_diagram.png')
        fig.savefig(fig_path, dpi=config.get('figure_dpi', 150), bbox_inches='tight')
        print(f"Saved combined phase diagram to {fig_path}")
    
    return fig

def _as_list(values) -> list:
    """Return ``values`` as a plain list; None becomes an empty list."""
    return [] if values is None else list(values)


def plot_response_difference(
    results: List[Dict],
    modulator_label: str = "Modulator",
    protein_label: str = "Component",
    figsize: Tuple[int, int] = (10, 6),
    custom_colors: Optional[List] = None,
    show_errors: bool = True,
    confidence_level: float = 0.95,
    legend_loc: str = 'best'
) -> plt.Figure:
    """
    Plot the difference between non-phase separated and phase separated responses.
    
    For every experiment, the phase separated (PS) and non-phase separated
    (NPS) responses are paired at modulator values present in both series
    (exact equality) and the difference NPS - PS is drawn as one curve.
    Experiments without any common modulator value are reported and skipped.
    
    Parameters:
    ----------
    results : List[Dict]
        List of result dictionaries from analyze_and_plot_dominance. Each
        entry must contain 'experiment'; the response, modulator value and
        error sequences may be lists or arrays and default to empty.
    modulator_label : str, optional
        Label for the modulator axis. Defaults to "Modulator".
    protein_label : str, optional
        Label for the protein component. Defaults to "Component".
        Accepted for interface compatibility; not used in the plot.
    figsize : Tuple[int, int], optional
        Figure size. Defaults to (10, 6).
    custom_colors : Optional[List], optional
        Custom color palette. If None, a default colormap is used. A palette
        shorter than ``results`` is cycled.
    show_errors : bool, optional
        Whether to show error bands. Defaults to True.
    confidence_level : float, optional
        Confidence level for error bands. Defaults to 0.95.
    legend_loc : str, optional
        Location of the legend. Defaults to 'best'.
    
    Returns:
    -------
    plt.Figure
        The generated matplotlib figure.
    """
    if custom_colors is None:
        colors = sns.color_palette('viridis', len(results))
    else:
        colors = custom_colors
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Two-sided z-score for the requested confidence level
    z_score = stats.norm.ppf((1 + confidence_level) / 2)
    
    for i, result in enumerate(results):
        # Normalise to plain lists so NumPy arrays are handled the same way
        ps_responses = _as_list(result.get('phase_sep_responses'))
        nps_responses = _as_list(result.get('non_phase_sep_responses'))
        ps_mod_values = _as_list(result.get('phase_sep_modulator_values'))
        nps_mod_values = _as_list(result.get('non_phase_sep_modulator_values'))
        ps_errors = _as_list(result.get('phase_sep_errors'))
        nps_errors = _as_list(result.get('non_phase_sep_errors'))
        
        # Map each NPS modulator value to its first position. This replaces
        # the repeated linear `.index()` search with a single pass.
        first_nps_index = {}
        for nps_idx, mod_val in enumerate(nps_mod_values):
            first_nps_index.setdefault(mod_val, nps_idx)
        
        # Pair up positions where both PS and NPS data exist
        index_pairs = [(ps_idx, first_nps_index[mod_val])
                       for ps_idx, mod_val in enumerate(ps_mod_values)
                       if mod_val in first_nps_index]
        
        if not index_pairs:
            print(f"No common modulator values found for {result['experiment']}")
            continue
        
        common_mod_values = [ps_mod_values[ps_idx] for ps_idx, _ in index_pairs]
        
        # Calculate the differences (NPS - PS)
        differences = [nps_responses[nps_idx] - ps_responses[ps_idx]
                       for ps_idx, nps_idx in index_pairs]
        
        color = colors[i % len(colors)]
        
        # Plot the difference curve
        ax.plot(
            common_mod_values, 
            differences,
            '-o', 
            color=color,
            label=result['experiment'],
            linewidth=2,
            markersize=5
        )
        
        # Add confidence intervals when both error series are available.
        # The two errors are combined in quadrature.
        if show_errors and ps_errors and nps_errors:
            combined_errors = [np.sqrt(ps_errors[ps_idx]**2 + nps_errors[nps_idx]**2)
                               for ps_idx, nps_idx in index_pairs]
            lower_bound = [diff - z_score * err
                           for diff, err in zip(differences, combined_errors)]
            upper_bound = [diff + z_score * err
                           for diff, err in zip(differences, combined_errors)]
            ax.fill_between(
                common_mod_values,
                lower_bound,
                upper_bound,
                color=color,
                alpha=0.2
            )
    
    # Add a zero line for reference
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    
    # Set up the plot
    ax.set_xlabel(f'{modulator_label} Concentration')
    ax.set_ylabel('Response Difference (Mixed - Demixed)')
    ax.grid(True, alpha=0.3)
    
    # Only draw a legend if at least one curve was plotted
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc=legend_loc)
    
    # Add confidence level to title if showing errors
    title = "Difference Between Mixed and Demixed Responses"
    if show_errors:
        title += f" with {int(confidence_level*100)}% Confidence Intervals"
    ax.set_title(title)
    
    fig.tight_layout()
    return fig

def analyze_and_plot_dominance(datasets: List[pd.DataFrame], 
                              experiment_labels: List[str],
                              col_names: Dict, 
                              comp_labels: Dict,
                              limits: Dict,
                              analysis_config: Dict,
                              vis_config: Dict) -> Tuple[plt.Figure, List[Dict]]:
    """
    Analyze dominance across multiple experiments and create a comparative plot with confidence intervals.
    
    Each dataset is passed through ``pd_analysis.dominance_sweep``; the phase
    separated responses of all experiments are then compared in a single
    figure by ``pd_analysis.plot_dominance_comparison``. Input dataframes are
    not modified. Empty datasets are reported and skipped.
    
    Parameters:
    ----------
    datasets : List[pd.DataFrame]
        List of dataframes containing the experimental data
    experiment_labels : List[str]
        Labels for each experiment
    col_names : Dict
        Dictionary of column names
    comp_labels : Dict
        Dictionary of component labels for plots
    limits : Dict
        Dictionary with concentration limits
    analysis_config : Dict
        Analysis configuration. 'protein_role' equal to 'A' makes component A
        the analyzed protein and B the modulator; any other value selects B.
    vis_config : Dict
        Visualization configuration
    
    Returns:
    -------
    Tuple[plt.Figure, List[Dict]]
        The generated figure and a list of results dictionaries, one per
        analyzed dataset, with the keys 'experiment', 'phase_sep_responses',
        'phase_sep_errors', 'non_phase_sep_responses', 'non_phase_sep_errors',
        'phase_sep_modulator_values' and 'non_phase_sep_modulator_values'
    """
    # Extract relevant configuration
    comp_A_col = col_names['comp_A_col']
    comp_B_col = col_names['comp_B_col']
    phase_sep_col = col_names['phase_sep_col']
    dilute_intensity_col = col_names['dilute_intensity_col']
    total_intensity_col = col_names['total_intensity_col']
    condensate_label_col = col_names['condensate_label_col']
    
    comp_A_label = comp_labels['comp_A_label']
    comp_B_label = comp_labels['comp_B_label']
    
    comp_A_lims = limits['comp_A_lims']
    comp_B_lims = limits['comp_B_lims']
    
    # Determine protein and modulator based on configuration
    if analysis_config['protein_role'] == 'A':
        protein_col, protein_lims, protein_label = comp_A_col, comp_A_lims, comp_A_label
        modulator_col, modulator_lims, modulator_label = comp_B_col, comp_B_lims, comp_B_label
    else:  # 'B'
        protein_col, protein_lims, protein_label = comp_B_col, comp_B_lims, comp_B_label
        modulator_col, modulator_lims, modulator_label = comp_A_col, comp_A_lims, comp_A_label
    
    # Color palette for dominance lines ('default' or a missing key means viridis)
    palette_name = vis_config.get('dominance_color_palette', 'default')
    if palette_name == 'default':
        palette_name = 'viridis'
    colors = sns.color_palette(palette_name, len(datasets))
    
    # Set up phase palette for dominance sweep plots
    phase_sep_color_palette = sns.color_palette('RdBu_r', 5)
    tieline_color_demixed = [phase_sep_color_palette[-1], phase_sep_color_palette[-3]]
    tieline_color_mixed = [phase_sep_color_palette[2], phase_sep_color_palette[0]]
    
    # Lists to store results
    all_responses = []
    all_response_errors = []
    all_modulator_values = []
    all_results = []
    analyzed_labels = []  # labels of the datasets that were actually analyzed
    
    # Analyze each dataset
    for df, label in zip(datasets, experiment_labels):
        print(f"Analyzing dominance for {label}...")
        
        if len(df) == 0:
            # No rows means there is no condensate label to read below
            print(f"Skipping {label}: dataset is empty")
            continue
        
        # Ensure condensate label column exists. `assign` returns a new
        # dataframe, so the caller's data is left unchanged.
        if condensate_label_col not in df.columns:
            df = df.assign(**{condensate_label_col: 'DefaultCondensate'})
        
        # Calculate dominance
        (phase_sep_responses,
         non_phase_sep_responses,
         phase_sep_modulator_values,
         non_phase_sep_modulator_values,
         phase_sep_errors,
         non_phase_sep_errors) = pd_analysis.dominance_sweep(
            df=df,
            protein_conc_col=protein_col,
            protein_min=protein_lims[0],
            protein_max=protein_lims[1],
            modulator_conc_col=modulator_col,
            modulator_min=modulator_lims[0],
            modulator_max=modulator_lims[1],
            condensate_label=df[condensate_label_col].iloc[0],
            numberWindows=analysis_config['n_windows'],
            numberResponsePlots=analysis_config['n_response_plots'],
            minimumDataThreshold=analysis_config['min_data_threshold'],
            modulator_label=modulator_label,
            protein_label=protein_label,
            phase_sep_col=phase_sep_col,
            dilute_intensity_col=dilute_intensity_col,
            total_intensity_col=total_intensity_col,
            tieline_color_demixed=tieline_color_demixed,
            tieline_color_mixed=tieline_color_mixed
        )
        
        # Store results
        analyzed_labels.append(label)
        all_responses.append(phase_sep_responses)
        all_response_errors.append(phase_sep_errors)
        all_modulator_values.append(phase_sep_modulator_values)
        all_results.append({
            'experiment': label,
            'phase_sep_responses': phase_sep_responses,
            'phase_sep_errors': phase_sep_errors,
            'non_phase_sep_responses': non_phase_sep_responses,
            'non_phase_sep_errors': non_phase_sep_errors,
            'phase_sep_modulator_values': phase_sep_modulator_values,
            'non_phase_sep_modulator_values': non_phase_sep_modulator_values
        })
    
    # Create comparative dominance plot with confidence intervals
    fig = pd_analysis.plot_dominance_comparison(
        response_list=all_responses,
        modulator_values_list=all_modulator_values,
        response_errors_list=all_response_errors,  # Pass errors for confidence intervals
        experiment_labels=analyzed_labels,
        modulator_label=modulator_label,
        response_component_label=protein_label,
        figsize=(6, 3),
        custom_colors=colors,
        confidence_level=0.95,  # 95% confidence intervals
        show_errors=True
    )
    
    # Save figure if configured. Save through the figure object rather than
    # pyplot's "current figure", since dominance_sweep may have created other
    # figures along the way.
    if vis_config['save_figures']:
        os.makedirs(vis_config['output_directory'], exist_ok=True)
        fig_path = os.path.join(vis_config['output_directory'], 'comparative_dominance.png')
        fig.savefig(fig_path, dpi=vis_config['figure_dpi'], bbox_inches='tight')
        print(f"Saved comparative dominance plot to {fig_path}")
    
    return fig, all_results
# %%
# Main execution block: load the data, resolve the concentration limits,
# draw the phase diagrams and dominance plots, then print a summary.

# Create output directory if needed
if vis_config['save_figures']:
    os.makedirs(vis_config['output_directory'], exist_ok=True)

# Load experimental data
print("Loading experimental data...")
datasets, experiment_labels = load_experimental_data(file_config)
print(f"Loaded {len(datasets)} datasets with labels: {experiment_labels}")

# Determine concentration limits (entries set to None are auto-detected)
limits = determine_concentration_limits(datasets, column_config, limits)
print(f"Concentration limits - Component A: {limits['comp_A_lims']}, Component B: {limits['comp_B_lims']}")

# Create combined phase diagram if configured
if analysis_config['combined_phase_diagram']:
    print("Creating combined phase diagram...")
    phase_fig = plot_combined_phase_diagram(
        datasets,
        experiment_labels,
        column_config,
        labels,
        limits,
        vis_config
    )
    plt.show()

# Analyze dominance and create comparative plot
print("Analyzing dominance across all experiments...")
dominance_fig, results = analyze_and_plot_dominance(
    datasets,
    experiment_labels,
    column_config,
    labels,
    limits,
    analysis_config,
    vis_config
)
plt.show()

# Resolve the labels for the difference plot from the configuration, using
# the same protein/modulator assignment as analyze_and_plot_dominance. Names
# local to that function are not visible here, so they cannot be reused.
if analysis_config['protein_role'] == 'A':
    difference_protein_label = labels['comp_A_label']
    difference_modulator_label = labels['comp_B_label']
else:
    difference_protein_label = labels['comp_B_label']
    difference_modulator_label = labels['comp_A_label']

# Use the palette configured for the dominance plots; None selects the
# plotting function's own default.
difference_palette = vis_config.get('dominance_color_palette', 'default')
if difference_palette == 'default' or not results:
    difference_colors = None
else:
    difference_colors = sns.color_palette(difference_palette, len(results))

difference_fig = plot_response_difference(
    results,
    # Empty labels fall back to generic axis names
    modulator_label=difference_modulator_label or "Modulator",
    protein_label=difference_protein_label or "Component",
    custom_colors=difference_colors,
    confidence_level=0.95,
    show_errors=True
)

# Save the difference figure alongside the other outputs if configured
if vis_config['save_figures']:
    difference_path = os.path.join(vis_config['output_directory'], 'response_difference.png')
    difference_fig.savefig(difference_path, dpi=vis_config['figure_dpi'], bbox_inches='tight')
    print(f"Saved response difference plot to {difference_path}")
plt.show()

# Print summary statistics (dominance is computed as 1 - response)
print("\nDominance Analysis Summary:")
print("-" * 50)
for result in results:
    exp_name = result['experiment']
    dominance = 1 - np.asarray(result['phase_sep_responses'], dtype=float)
    
    print(f"Experiment: {exp_name}")
    if dominance.size == 0:
        # min/max of an empty array would raise, so report this instead
        print("  No phase separated responses available")
    else:
        print(f"  Average Dominance: {dominance.mean():.3f}")
        print(f"  Maximum Dominance: {dominance.max():.3f}")
        print(f"  Minimum Dominance: {dominance.min():.3f}")
    print("-" * 50)

print("\nAnalysis complete!")