In [None]:
# =============================================================================
# IMPORTS AND SETUP
# =============================================================================
"""
Astrocyte Calcium Signaling Analysis Pipeline

This notebook processes AQuA2-exported calcium event data from astrocyte recordings,
applying feature-specific normalization and performing statistical comparisons
across experimental groups.

Workflow:
    1. Load raw CSV files from AQuA2 output
    2. Filter events by frame range (exclude artifacts)
    3. Normalize features relative to baseline condition
    4. Generate timepoint plots (Baseline -> Drug -> Washout)
    5. Perform statistical analysis (Kruskal-Wallis, Dunn's post-hoc, FDR correction)
    6. Export normalized data for use in analysis_time.ipynb

Experimental Groups:
    - WT: Wild Type (control)
    - AV: Antagonist Volinanserin (5-HT2A antagonist)
    - IP: IP3R2 cKO (calcium signaling knockout)
    - CE: CalEx (astrocytic Ca2+ extruder, hPMCA2w/b; attenuates astrocyte Ca2+ signals)
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
from scipy import stats
from scipy.stats import kruskal, rankdata
from statsmodels.stats.multitest import multipletests

# Suppress RuntimeWarnings from statistical tests with small samples.
# NOTE(review): this filter is process-wide — it silences every
# RuntimeWarning raised after this cell runs (e.g. numpy/pandas
# divide-by-zero or empty-slice warnings), not only those from the
# statistical tests.
warnings.filterwarnings('ignore', category=RuntimeWarning)


In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# --- Analysis Parameters ---
# Events are kept only when MIN_FRAME <= 'Starting Frame' <= MAX_FRAME
# (inclusive on both ends; applied to every file in process_file()).
MIN_FRAME = 20       # Start frame (excludes early recording artifacts)
MAX_FRAME = 100      # End frame (excludes late recording artifacts)
# Joined onto CWD when building input paths; save_normalized_data() also
# writes the normalized CSVs back into this tree.
OUTPUT_DIR = 'Output__'  # Directory containing AQuA2 output
CHANNEL_SUFFIX = 'Ch1'   # AQuA2 channel identifier in filenames (set to '' if not used)
CWD = os.getcwd()  # captured once, when this cell is executed

# --- Experimental Group Configuration ---
# Each group specifies:
#   - path: folder name in OUTPUT_DIR
#   - drug_suffix: filename suffix for drug condition (e.g., 'psi', 'psi+antag')
#   - slices: dict mapping data subfolders to slice numbers
# Input files are expected at (see build_file_paths()):
#   OUTPUT_DIR/<path>/<subfolder>/slice<N>_{baseline|<drug_suffix>|washout}_AQuA2_<CHANNEL_SUFFIX>.csv
# Dict order sets the processing/plotting order and the group order in the
# statistics tables; the statistical-analysis cell uses 'WT' as control.
# NOTE(review): some slice lists are non-contiguous (e.g. IP data1 = [2, 4, 5]),
# presumably because slices were excluded — confirm against the lab notes.
DATA_CONFIG = {
    'WT': {
        'path': 'WT',
        'drug_suffix': 'psi',
        'slices': {
            'data1': [1, 2, 3],
            'data2': [1, 2, 3, 4],
        }
    },
    'AV': {
        'path': 'Antagonist- Volinanserin',
        'drug_suffix': 'psi+antag',
        'slices': {
            'data1': [1, 2, 3, 4],
            'data2': [1, 2, 3, 4],
        }
    },
    'IP': {
        'path': 'IP3R2 cKO',
        'drug_suffix': 'psi',
        'slices': {
            'data1': [2, 4, 5],
            'data2': [1, 2, 3],
        }
    },
    'CE': {
        'path': 'CalEx',
        'drug_suffix': 'psi',
        'slices': {
            'data1': [1, 2, 3, 4, 5, 6],
            'data2': [1, 2, 3],
        }
    },
}

# =============================================================================
# FEATURE CATEGORIES FOR NORMALIZATION
# =============================================================================
"""
Normalization Strategy:
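    All features are normalized as a fold change relative to the median of the same slice's baseline recording (computed per slice, before slices are pooled); a zero or NaN baseline median yields NaN.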
    
    FOLD-CHANGE (value / baseline_median):
        - Intensity/magnitude features (dF, dF/F, AUC)
        - Spatial features (Area, Perimeter, Circularity)
        - Temporal features (durations)
        - Network/count features (event counts)
        - Interpretation: 2.0 = doubled, 0.5 = halved
"""

# Fold-change features: value / baseline_median
# These strings are matched exactly against the feature names in the AQuA2
# CSVs (plot_timepoints() keeps only names present in the data; the stats
# code looks them up as column names). The "averge" spelling below
# presumably mirrors the AQuA2 export headers — NOTE(review): confirm
# against a CSV before "correcting" it, or those features silently vanish.
FOLD_CHANGE_FEATURES = [
    "Curve - Max Df",
    "Curve - Max Dff",
    "Curve - dat AUC",
    "Curve - df AUC",
    "Curve - dff AUC",
    "Basic - Area",
    "Basic - Perimeter (only for 2D video)",
    "Basic - Circularity",  
    "Curve - Duration of visualized event overlay",
    "Curve - Duration 50% to 50% based on averge dF/F",
    "Curve - Duration 10% to 10% based on averge dF/F",
    "Curve - Rising duration 10% to 90% based on averge dF/F",
    "Curve - Decaying duration 90% to 10% based on averge dF/F",
    "Network - number of events in the same location",
    "Network - number of events in the same location with similar size only",
    "Network - maximum number of events appearing at the same time",
]

# Combined list of all features to analyze
# (currently an alias: the very same list object as FOLD_CHANGE_FEATURES).
ALL_FEATURES = FOLD_CHANGE_FEATURES


In [None]:
# =============================================================================
# DATA PROCESSING FUNCTIONS
# =============================================================================

def process_file(file_path):
    """
    Load and preprocess a single AQuA2 CSV export file.

    AQuA2 exports data in transposed format (features as rows, events as columns).
    This function:
        1. Transposes to standard format (events as rows)
        2. Drops irrelevant columns (Channel, Index, etc.)
        3. Converts all values to numeric (unparseable cells become NaN)
        4. Filters events by frame range (MIN_FRAME..MAX_FRAME, inclusive),
           when a 'Starting Frame' column is present

    Args:
        file_path: Path to AQuA2 CSV file

    Returns:
        DataFrame with events as rows and features as columns,
        or an empty DataFrame if the file doesn't exist or is empty
    """
    if not os.path.exists(file_path):
        return pd.DataFrame()

    # Load and transpose (AQuA2 exports features as rows)
    try:
        df = pd.read_csv(file_path, header=None)
    except pd.errors.EmptyDataError:
        # A zero-byte / content-less export has nothing to analyze; honour the
        # documented "empty DataFrame" contract instead of crashing the run.
        return pd.DataFrame()
    df_transposed = df.set_index(0).T.reset_index(drop=True)

    # Remove non-feature columns (absent ones are ignored)
    columns_to_drop = [
        'Channel',
        'Index',
        'Curve - P Value on max Dff (-log10)',
        'Curve - Decay tau'
    ]
    df_clean = df_transposed.drop(columns=columns_to_drop, errors='ignore')

    # Convert all columns to numeric; non-numeric cells become NaN
    for col in df_clean.columns:
        df_clean[col] = pd.to_numeric(df_clean[col], errors='coerce')

    # Filter by frame range to exclude artifacts (bounds are inclusive)
    if 'Starting Frame' in df_clean.columns:
        in_range = (
            (df_clean['Starting Frame'] >= MIN_FRAME) &
            (df_clean['Starting Frame'] <= MAX_FRAME)
        )
        return df_clean[in_range].copy()
    return df_clean.copy()


def normalize_to_baseline(df, baseline_medians):
    """
    Express each feature as a fold change of its baseline median.

    Every column except "Starting Frame" is divided by the matching entry of
    `baseline_medians`:  normalized_value = value / baseline_median.
    A column whose baseline median is missing, NaN or zero cannot be scaled
    and is filled with NaN. Any +/-inf produced along the way is turned into
    NaN as well. "Starting Frame" is carried over unscaled.

    Args:
        df: DataFrame of raw per-event feature values
        baseline_medians: Series of baseline medians indexed by column name

    Returns:
        A new DataFrame (the input is left untouched) of fold-change values
    """
    scaled = df.copy()

    for column in df.columns:
        if column != "Starting Frame":
            divisor = baseline_medians.get(column, 0)
            usable = not pd.isna(divisor) and divisor != 0
            scaled[column] = df[column] / divisor if usable else np.nan

    return scaled.replace([np.inf, -np.inf], np.nan)


def analyze_slice(baseline_path, drug_path, washout_path):
    """
    Load one slice's three condition files and normalize them to its baseline.

    All three files are loaded first. If any of them yields no usable events,
    the slice is rejected. Otherwise every condition is divided by the
    per-feature median of this slice's baseline (zero medians are treated
    as missing so they cannot cause a division by zero).

    Args:
        baseline_path: Path to baseline condition CSV
        drug_path: Path to drug/PSI condition CSV
        washout_path: Path to washout condition CSV

    Returns:
        Tuple (normalized_baseline, normalized_drug, normalized_washout);
        three empty DataFrames if any file is missing or empty
    """
    frames = [process_file(path) for path in (baseline_path, drug_path, washout_path)]

    # Every condition must contribute data for the slice to be usable
    if any(frame.empty for frame in frames):
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    reference_medians = frames[0].median().replace(0, np.nan)
    return tuple(normalize_to_baseline(frame, reference_medians) for frame in frames)


def build_file_paths(group_config):
    """
    List the (baseline, drug, washout) CSV paths for every slice of a group.

    Paths follow the pattern
        CWD/OUTPUT_DIR/<path>/<subfolder>/slice<N>_<condition>_AQuA2[_<CHANNEL_SUFFIX>].csv
    where <condition> is 'baseline', the group's drug_suffix, or 'washout'.

    Args:
        group_config: Dict with 'path', 'drug_suffix', and 'slices' keys

    Returns:
        List of (baseline_path, drug_path, washout_path) tuples, in the
        order the subfolders and slice numbers appear in the config
    """
    group_folder = group_config['path']
    conditions = ('baseline', group_config['drug_suffix'], 'washout')
    channel_tag = '' if not CHANNEL_SUFFIX else f'_{CHANNEL_SUFFIX}'

    return [
        tuple(
            os.path.join(
                CWD, OUTPUT_DIR, group_folder, subfolder,
                f'slice{number}_{condition}_AQuA2{channel_tag}.csv',
            )
            for condition in conditions
        )
        for subfolder, numbers in group_config['slices'].items()
        for number in numbers
    ]


def process_group(group_name):
    """
    Process all slices for an experimental group and combine results.

    A slice contributes only when all three condition files exist and each
    yields events. Every missing file is reported, including a missing drug
    or washout file, not just a missing baseline. Slices rejected for lack of
    usable events are also reported, so no slice is dropped silently.

    Args:
        group_name: Key in DATA_CONFIG (e.g., 'WT', 'AV', 'IP', 'CE')

    Returns:
        Tuple of (combined_baseline, combined_drug, combined_washout) DataFrames
        with data from all usable slices concatenated (empty DataFrames if none)
    """
    print(f"Processing Group: {group_name}...")
    config = DATA_CONFIG[group_name]

    base_list, drug_list, wash_list = [], [], []

    for triplet in build_file_paths(config):
        missing = [path for path in triplet if not os.path.exists(path)]
        if missing:
            for path in missing:
                print(f"  Missing: {path}")
            continue

        norm_base, norm_drug, norm_wash = analyze_slice(*triplet)
        if norm_base.empty:
            # analyze_slice() returns empties when any condition has no events
            # (e.g. none inside the MIN_FRAME..MAX_FRAME window).
            print(f"  Skipped (no usable events in one or more conditions): {triplet[0]}")
            continue

        base_list.append(norm_base)
        drug_list.append(norm_drug)
        wash_list.append(norm_wash)

    # Combine all slices
    if base_list:
        return (
            pd.concat(base_list, ignore_index=True),
            pd.concat(drug_list, ignore_index=True),
            pd.concat(wash_list, ignore_index=True)
        )
    return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()


def save_normalized_data(base_df, drug_df, wash_df, group_name):
    """
    Write a group's normalized condition tables to CSV (index omitted).

    Output location: CWD/OUTPUT_DIR/<group_path>/<group>_<condition>_normalized.csv,
    with <condition> being 'baseline', 'drug' or 'washout'. The target
    folder is created if necessary and existing files are overwritten.

    Args:
        base_df: Normalized baseline DataFrame
        drug_df: Normalized drug condition DataFrame
        wash_df: Normalized washout DataFrame
        group_name: Group identifier (e.g., 'WT')
    """
    target_dir = os.path.join(CWD, OUTPUT_DIR, DATA_CONFIG[group_name]['path'])
    os.makedirs(target_dir, exist_ok=True)

    for frame, condition in ((base_df, 'baseline'), (drug_df, 'drug'), (wash_df, 'washout')):
        destination = os.path.join(target_dir, f'{group_name}_{condition}_normalized.csv')
        frame.to_csv(destination, index=False)

    print(f"  Saved normalized data to: {target_dir}")


In [None]:
# =============================================================================
# STATISTICAL ANALYSIS FUNCTIONS
# =============================================================================

def dunns_test(groups, labels, control=None):
    """
    Perform Dunn's post-hoc test for pairwise comparisons after Kruskal-Wallis.

    Groups are compared through their mean ranks in the pooled sample, using a
    normal-approximation z-test for each pair:

        z = (Rbar_i - Rbar_j) / sqrt( N(N+1)/12 * C * (1/n_i + 1/n_j) )

    C = 1 - sum(t^3 - t) / (N^3 - N) is the tie correction, where t is the size
    of each set of tied values. Ties *shrink* the rank variance, so C
    multiplies the variance here. (It divides the Kruskal-Wallis H statistic,
    which is why, for two groups, z**2 equals the tie-corrected H.)

    Args:
        groups: List of 1-D arrays, one per group (NaNs already removed)
        labels: List of group labels corresponding to each array
        control: If specified and present in `labels`, only compare the other
                 groups against this control group.
                 Otherwise perform all pairwise comparisons.

    Returns:
        DataFrame with columns: Group1, Group2, Z, p_value, p_adjusted.
        Z > 0 means Group1 has the higher mean rank; p_value is two-sided;
        p_adjusted applies Benjamini-Hochberg FDR across the comparisons
        returned by this call.
    """
    group_sizes = [len(g) for g in groups]
    all_data = np.concatenate(groups)
    N = len(all_data)
    ranks = rankdata(all_data)

    # Mean rank of each group within the pooled ranking
    mean_ranks = []
    start = 0
    for size in group_sizes:
        mean_ranks.append(np.mean(ranks[start:start + size]))
        start += size

    # Tie correction: tied values share one average rank, so counting repeated
    # ranks yields the tie-set sizes. Use floats so t**3 cannot overflow.
    _, tie_counts = np.unique(ranks, return_counts=True)
    tie_counts = tie_counts.astype(float)
    if N > 1:
        tie_correction = 1.0 - np.sum(tie_counts**3 - tie_counts) / (N**3 - N)
    else:
        tie_correction = 0.0
    # All values tied -> C == 0 -> SE == 0, reported below as "no difference".
    tie_correction = max(tie_correction, 0.0)

    if control is not None and control in labels:
        # Only compare against control
        ctrl_idx = labels.index(control)
        pairs = [(ctrl_idx, j) for j in range(len(labels)) if j != ctrl_idx]
    else:
        # All pairwise comparisons
        pairs = [(i, j) for i in range(len(labels)) for j in range(i + 1, len(labels))]

    # Tie-corrected variance of a single rank
    rank_variance = (N * (N + 1) / 12.0) * tie_correction

    results = []
    for i, j in pairs:
        n_i, n_j = group_sizes[i], group_sizes[j]

        # Standard error for the difference in mean ranks
        se = np.sqrt(rank_variance * (1.0 / n_i + 1.0 / n_j))

        if se == 0:
            z_stat, p_val = 0.0, 1.0
        else:
            z_stat = (mean_ranks[i] - mean_ranks[j]) / se
            p_val = 2.0 * stats.norm.sf(abs(z_stat))  # Two-tailed

        results.append({
            'Group1': labels[i],
            'Group2': labels[j],
            'Z': z_stat,
            'p_value': p_val
        })

    results_df = pd.DataFrame(results, columns=['Group1', 'Group2', 'Z', 'p_value'])

    # FDR correction on pairwise p-values (column always present)
    if results_df.empty:
        results_df['p_adjusted'] = pd.Series(dtype=float)
    else:
        results_df['p_adjusted'] = multipletests(
            results_df['p_value'], method='fdr_bh'
        )[1]

    return results_df


def perform_statistical_tests(feature_name, group_data_dict, control_group='WT'):
    """
    Run the Kruskal-Wallis omnibus test for one feature across all groups.

    Groups with fewer than two non-NaN observations are excluded from the
    test. At least two remaining groups are required, and the pooled values
    must not all be identical (the test is undefined then); otherwise H and p
    are NaN and a 'note' explains why. Dunn's post-hoc comparisons are *not*
    run here: analyze_all_features() runs them only for features that are
    significant after FDR correction.

    Args:
        feature_name: Column name to analyze
        group_data_dict: Dict mapping group names to DataFrames
        control_group: Unused here; kept for call compatibility (the control
                       is applied by analyze_all_features()/dunns_test())

    Returns:
        Dict with 'feature', 'kruskal_h', 'kruskal_p', 'dunn_results' (None),
        'group_keys' (groups that entered the test), and for *every* group
        'median_<group>' and 'n_<group>' (NaN median / 0 when no data),
        plus 'note' when the test could not be performed
    """
    labels = list(group_data_dict.keys())

    # Extract non-NaN values for each group (empty array if feature absent)
    arrays = []
    for name in labels:
        df = group_data_dict[name]
        if feature_name in df.columns:
            arrays.append(df[feature_name].dropna().values)
        else:
            arrays.append(np.array([]))

    result = {
        'feature': feature_name,
        'kruskal_h': np.nan,
        'kruskal_p': np.nan,
        'dunn_results': None,
        'group_keys': []
    }

    # Per-group statistics are recorded even when the test cannot run, so
    # summary tables never meet missing median/n entries.
    for name, arr in zip(labels, arrays):
        result[f'median_{name}'] = np.median(arr) if len(arr) > 0 else np.nan
        result[f'n_{name}'] = len(arr)

    # Need at least 2 groups with >1 observation for comparison
    valid = [(name, arr) for name, arr in zip(labels, arrays) if len(arr) > 1]
    if len(valid) < 2:
        result['note'] = 'Insufficient data'
        return result

    valid_labels = [name for name, _ in valid]
    valid_arrays = [arr for _, arr in valid]
    result['group_keys'] = valid_labels

    # Kruskal-Wallis is undefined when every pooled value is identical
    pooled = np.concatenate(valid_arrays)
    if np.all(pooled == pooled[0]):
        result['note'] = 'All values identical'
        return result

    kruskal_h, kruskal_p = kruskal(*valid_arrays)
    result['kruskal_h'] = kruskal_h
    result['kruskal_p'] = kruskal_p
    return result


def analyze_all_features(group_dict, features_list, condition_name="Drug", 
                         control_group='WT', alpha=0.05):
    """
    Run Kruskal-Wallis on all features, apply FDR, then run Dunn's post-hoc.

    Workflow:
        1. Run Kruskal-Wallis for each feature
        2. Apply Benjamini-Hochberg FDR across the features that were actually
           tested. Untested features (NaN p, e.g. insufficient data) are
           excluded from the family and keep a NaN 'kruskal_p_fdr'.
        3. For features with kruskal_p_fdr < alpha, run Dunn's post-hoc
           against `control_group`

    Args:
        group_dict: Dict mapping group names to DataFrames
        features_list: List of feature column names to analyze
        condition_name: Label for output (e.g., "Drug", "Washout")
        control_group: Reference group for Dunn's pairwise comparisons
        alpha: Significance threshold (default 0.05)

    Returns:
        DataFrame with one row per feature, including 'kruskal_p_fdr' and a
        'dunn_results' DataFrame for significant features (None otherwise)
    """
    print(f"\n{'=' * 80}")
    print(f"Statistical Analysis: {condition_name} Condition")
    print(f"Method: Kruskal-Wallis + Dunn's post-hoc (FDR corrected)")
    print(f"Comparing groups: {', '.join(group_dict.keys())}")
    print(f"Control group: {control_group}")
    print(f"{'=' * 80}\n")

    # Step 1: Run Kruskal-Wallis for each feature
    results_df = pd.DataFrame([
        perform_statistical_tests(feature, group_dict, control_group)
        for feature in features_list
    ])
    if results_df.empty:
        return results_df

    # Step 2: FDR over the tested features only. Filling NaN with 1.0 would
    # inflate the number of hypotheses and report a fake FDR p for untested rows.
    results_df['kruskal_p_fdr'] = np.nan
    tested = results_df['kruskal_p'].notna()
    if tested.any():
        results_df.loc[tested, 'kruskal_p_fdr'] = multipletests(
            results_df.loc[tested, 'kruskal_p'].astype(float), method='fdr_bh'
        )[1]

    # Step 3: Run Dunn's post-hoc ONLY for features significant after FDR
    # (NaN compares False, so untested features are skipped)
    significant = results_df['kruskal_p_fdr'] < alpha
    for idx in results_df.index[significant]:
        feature = results_df.at[idx, 'feature']
        arrays = []
        labels = []
        for name in results_df.at[idx, 'group_keys']:
            df = group_dict[name]
            if feature in df.columns:
                arr = df[feature].dropna().values
                if len(arr) > 1:
                    arrays.append(arr)
                    labels.append(name)

        if len(arrays) >= 2:
            dunn_df = dunns_test(arrays, labels, control=control_group)
            results_df.at[idx, 'dunn_results'] = dunn_df

    return results_df


def get_significance_stars(p_value):
    """Map a p-value to '***' (<0.001), '**' (<0.01), '*' (<0.05) or 'ns'; missing -> 'N/A'."""
    if pd.isna(p_value):
        return "N/A"
    for threshold, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < threshold:
            return stars
    return "ns"


def print_results_summary(results_df):
    """
    Print a formatted table of Kruskal-Wallis test results.

    The FDR-corrected p-value is shown when available, otherwise the raw one.
    Group columns are taken from the per-group 'median_<group>' columns of
    the whole table, not from the first row's 'group_keys' (which is empty
    when that feature lacked data). Missing medians or sample sizes print as '-'.

    Args:
        results_df: DataFrame from analyze_all_features()
    """
    if results_df.empty:
        return

    width = 140
    print("\n" + "=" * width)
    print("KRUSKAL-WALLIS TEST RESULTS (FDR corrected)")
    print("=" * width)

    # Group names in first-appearance order of their median columns
    group_keys = [col[len('median_'):] for col in results_df.columns
                  if isinstance(col, str) and col.startswith('median_')]

    # Print header
    median_header = "Medians (" + "/".join(group_keys) + ")"
    n_header = "N (" + "/".join(group_keys) + ")"
    print(f"{'Feature':<50} {'H stat':<10} {'p (FDR)':<15} {'Sig':<5} {median_header:<30} {n_header}")
    print("-" * width)

    # Print each feature's results
    for _, row in results_df.iterrows():
        feature = row['feature']

        # Kruskal-Wallis results
        h_stat = row.get('kruskal_h', np.nan)
        kruskal_p = row.get('kruskal_p_fdr', row.get('kruskal_p', np.nan))

        if pd.isna(kruskal_p):
            p_str = "N/A"
            sig = "N/A"
        else:
            p_str = f"{kruskal_p:.4f}"
            sig = get_significance_stars(kruskal_p)

        h_str = f"{h_stat:.2f}" if pd.notna(h_stat) else "N/A"

        # Group medians and sample sizes. A row may lack these values (pandas
        # fills NaN), so int() is never applied to NaN.
        medians_list = []
        n_list = []
        for g in group_keys:
            val = row.get(f'median_{g}', np.nan)
            n_val = row.get(f'n_{g}', np.nan)
            medians_list.append("-" if pd.isna(val) else f"{val:.2f}")
            n_list.append("-" if pd.isna(n_val) else str(int(n_val)))
        median_str = "/".join(medians_list)
        n_str = "/".join(n_list)

        # Truncate long feature names
        feature_short = feature[:47] + "..." if len(feature) > 50 else feature
        print(f"{feature_short:<50} {h_str:<10} {p_str:<15} {sig:<5} {median_str:<30} {n_str}")

    print("=" * width + "\n")


def print_dunn_results(results_df, control_group='WT'):
    """
    Print Dunn's post-hoc test results for significant features.

    Only rows whose 'dunn_results' holds a non-empty DataFrame are printed.
    An empty results table, or one without a 'dunn_results' column, is
    reported as "no significant features" instead of raising KeyError.

    Args:
        results_df: DataFrame from analyze_all_features()
        control_group: Reference group for pairwise comparisons
    """
    no_results_msg = "\nNo significant features found for post-hoc testing.\n"

    if results_df.empty or 'dunn_results' not in results_df.columns:
        print(no_results_msg)
        return

    # Filter to features with Dunn's results
    has_dunn = results_df[results_df['dunn_results'].apply(
        lambda x: isinstance(x, pd.DataFrame) and len(x) > 0
    )]

    if len(has_dunn) == 0:
        print(no_results_msg)
        return

    print("\n" + "=" * 80)
    print("DUNN'S POST-HOC TEST RESULTS (FDR corrected)")
    print(f"Pairwise comparisons against control: {control_group}")
    print("=" * 80)

    for _, row in has_dunn.iterrows():
        dunn_df = row['dunn_results']
        valid_groups = row.get('group_keys', [])

        print(f"\n{row['feature']}:")

        # Print group medians
        medians = [f"{g}={row.get(f'median_{g}', np.nan):.2f}" for g in valid_groups]
        print(f"  Group medians: {', '.join(medians)}")
        print(f"  Pairwise comparisons (vs {control_group}):")

        for _, comp in dunn_df.iterrows():
            sig = get_significance_stars(comp['p_adjusted'])
            print(f"    {comp['Group1']} vs {comp['Group2']}: "
                  f"Z={comp['Z']:.3f}, p={comp['p_adjusted']:.4f} {sig}")

    print("\n" + "=" * 80 + "\n")


In [None]:
# =============================================================================
# PLOTTING FUNCTIONS
# =============================================================================

def calculate_mad(series):
    """
    Return the (unscaled) Median Absolute Deviation: median(|x - median(x)|).

    A robust spread estimate that pairs naturally with the median. NaNs are
    skipped by pandas' median; no 1.4826 normal-consistency factor is applied.

    Args:
        series: Pandas Series of numeric values

    Returns:
        MAD value (float); NaN for an empty or all-NaN series
    """
    center = series.median()
    absolute_deviations = (series - center).abs()
    return absolute_deviations.median()


def calculate_iqr_half(series):
    """
    Calculate half Interquartile Range (IQR/2) for symmetric error bars.

    Returns (Q3-Q1)/2, using numpy's default linear percentile interpolation
    on the non-NaN values, for use as symmetric error around the median.
    Normalized columns can be entirely NaN (a zero/NaN baseline median), so
    an empty input returns NaN instead of letting np.percentile raise.

    Args:
        series: Pandas Series (or any 1-D array-like) of numeric values

    Returns:
        Half-IQR value (float); NaN when there are no non-NaN values
    """
    values = pd.Series(series).dropna()
    if values.empty:
        return np.nan
    q75, q25 = np.percentile(values, [75, 25])
    return (q75 - q25) / 2


def plot_timepoints(base_df, drug_df, wash_df, features_to_plot, title="Analysis", 
                    error_type='iqr'):
    """
    Plot feature values across experimental timepoints (Baseline -> Drug -> Washout).

    Creates a multi-panel figure (4 panels per row) showing median values
    (± error) for each feature across the three experimental conditions.
    Conditions with an empty DataFrame are omitted; features missing from
    every condition are skipped.

    Args:
        base_df: Normalized baseline condition DataFrame
        drug_df: Normalized drug condition DataFrame
        wash_df: Normalized washout DataFrame
        features_to_plot: List of feature column names to include
        title: Figure title
        error_type: Type of error bars (default: 'iqr'; unknown values fall
            back to 'iqr')
            - 'mad': Median Absolute Deviation (robust, pairs with median)
            - 'iqr': Half Interquartile Range (robust, good for skewed data)
    """
    # Select error calculation function
    error_functions = {
        'mad': ('MAD', calculate_mad),
        'iqr': ('IQR/2', calculate_iqr_half),
    }
    error_label, error_func = error_functions.get(error_type, ('IQR/2', calculate_iqr_half))

    # Organize data by condition
    conditions = {'BASELINE': base_df, 'DRUG': drug_df, 'WASHOUT': wash_df}
    medians_dict = {}
    errors_dict = {}

    # Calculate statistics for each condition
    for name, df in conditions.items():
        if df.empty:
            continue
        numeric_cols = df.select_dtypes(include=['number']).columns
        medians_dict[name] = df[numeric_cols].median()
        errors_dict[name] = df[numeric_cols].apply(error_func)

    if not medians_dict:
        print(f"No data available for {title}")
        return

    # Combine into DataFrames for easy plotting (rows: features, cols: conditions)
    combined_medians = pd.DataFrame(medians_dict)
    combined_errors = pd.DataFrame(errors_dict)

    # Filter to available features
    available_features = [f for f in features_to_plot if f in combined_medians.index]
    if not available_features:
        print(f"No features available for {title}")
        return

    # Setup figure grid. squeeze=False always yields a 2-D array of Axes, so
    # ravel() works for any feature count. With n_cols=4, one feature still
    # produces a row of 4 axes, and the old `axes = [axes]` special case made
    # axes[0] an ndarray rather than an Axes.
    n_cols = 4
    n_rows = (len(available_features) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 3.5),
                             constrained_layout=True, squeeze=False)
    axes = axes.ravel()

    fig.suptitle(f"{title} (Median \u00B1 {error_label})", fontsize=12, fontweight='bold')

    # Plot each feature
    for i, feature in enumerate(available_features):
        ax = axes[i]
        medians = combined_medians.loc[feature]
        errors = combined_errors.loc[feature]
        x_pos = range(len(medians))

        # Plot median points
        ax.plot(x_pos, medians, 'o', color='black', markersize=8, label='Median')

        # Add symmetric error bars around each median
        for j, pos in enumerate(x_pos):
            ax.vlines(pos, medians.iloc[j] - errors.iloc[j],
                      medians.iloc[j] + errors.iloc[j], color='grey', lw=2)

        ax.set_ylabel('Fold Change')
        ax.set_title(feature, fontsize=9)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(medians.index)
        ax.grid(True, linestyle='--', alpha=0.5)

    # Remove unused subplots
    for j in range(len(available_features), len(axes)):
        fig.delaxes(axes[j])

    plt.show()


def plot_group_comparison(group_dict, feature_name, condition="Drug", error_type='iqr',
                          control_group='WT'):
    """
    Create a bar chart comparing one feature across all experimental groups.

    Bars show the median of the non-NaN values with symmetric error bars.
    A group with no data for the feature is drawn as an empty slot (NaN).
    It is no longer drawn as a zero-height bar, which would read as a
    fold change of 0, and an all-NaN column no longer crashes the IQR
    computation.

    Args:
        group_dict: Dict mapping group names to DataFrames
        feature_name: Column name to plot
        condition: Label for the condition (used in title)
        error_type: Type of error bars - 'mad' or 'iqr' (default: 'iqr';
            unknown values fall back to 'iqr')
        control_group: Group drawn in gray; all others in sky blue (default: 'WT')
    """
    error_functions = {
        'mad': ('MAD', calculate_mad),
        'iqr': ('IQR/2', calculate_iqr_half),
    }
    error_label, error_func = error_functions.get(error_type, ('IQR/2', calculate_iqr_half))

    groups = list(group_dict.keys())
    medians = []
    errors = []

    for group in groups:
        df = group_dict[group]
        if feature_name in df.columns:
            values = df[feature_name].dropna()
        else:
            values = pd.Series(dtype=float)

        if values.empty:
            medians.append(np.nan)
            errors.append(np.nan)
        else:
            medians.append(values.median())
            errors.append(error_func(values))

    fig, ax = plt.subplots(figsize=(6, 4))
    x_pos = np.arange(len(groups))

    bars = ax.bar(x_pos, medians, yerr=errors, align='center',
                  alpha=0.7, ecolor='black', capsize=10)

    for bar, group in zip(bars, groups):
        bar.set_color('gray' if group == control_group else 'skyblue')

    ax.set_ylabel(feature_name)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(groups)
    ax.set_title(f'{condition} Condition: {feature_name}\n(Median \u00B1 {error_label})')
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()


In [None]:
# =============================================================================
# MAIN EXECUTION: DATA PROCESSING
# =============================================================================
"""
Process all experimental groups:
    1. Load raw AQuA2 CSV files for each slice
    2. Apply fold-change normalization relative to baseline
    3. Combine data across slices
    4. Save normalized data for downstream analysis
"""

# Storage for processed data: condition -> {group name -> normalized DataFrame}
group_data = {condition: {} for condition in ('baseline', 'drug', 'washout')}

# Process each experimental group
print("\n" + "=" * 60)
print("PROCESSING DATA")
print("=" * 60)

for group_name in DATA_CONFIG:
    base_df, drug_df, wash_df = process_group(group_name)

    for condition, frame in (('baseline', base_df), ('drug', drug_df), ('washout', wash_df)):
        group_data[condition][group_name] = frame

    # Persist normalized tables for analysis_time.ipynb; groups without data are only flagged
    if base_df.empty:
        print(f"  Warning: No valid data for {group_name}")
        continue
    save_normalized_data(base_df, drug_df, wash_df, group_name)

print("\nData processing complete.")


In [None]:
# =============================================================================
# MAIN EXECUTION: VISUALIZATION
# =============================================================================
"""
Generate timepoint plots showing Baseline -> Drug -> Washout transitions
for each experimental group.

Error bar options:
    - 'iqr': Half Interquartile Range (robust, good for skewed fold-change data)
    - 'mad': Median Absolute Deviation (robust alternative)
"""

# Choose error type: 'iqr' or 'mad'
ERROR_TYPE = 'iqr'

print("\n" + "=" * 60)
print("GENERATING TIMEPOINT PLOTS")
print(f"Error bars: {ERROR_TYPE.upper()}")
print("=" * 60)

for group_name in DATA_CONFIG:
    base_df, drug_df, wash_df = (
        group_data[condition][group_name] for condition in ('baseline', 'drug', 'washout')
    )

    # Groups without drug-condition data have nothing to plot
    if drug_df.empty:
        print(f"\nSkipping {group_name}: No data available")
        continue

    print(f"\nPlotting: {group_name}")
    plot_timepoints(
        base_df, drug_df, wash_df,
        ALL_FEATURES,
        title=f'{group_name} Analysis',
        error_type=ERROR_TYPE,
    )


In [None]:
# =============================================================================
# MAIN EXECUTION: STATISTICAL ANALYSIS
# =============================================================================
"""
Perform statistical comparisons across experimental groups:
    - Kruskal-Wallis (non-parametric omnibus test)
    - Dunn's post-hoc (pairwise comparisons vs WT control)
    - FDR correction (Benjamini-Hochberg) at both omnibus and post-hoc levels
"""

print("\n" + "=" * 60)
print("STATISTICAL ANALYSIS")
print("=" * 60)


def _analyze_condition(condition_label, control_group='WT'):
    """Run, print and return the cross-group analysis for one condition ('Drug'/'Washout')."""
    print(f"\n>>> Analyzing {condition_label.upper()} condition...")
    condition_results = analyze_all_features(
        group_data[condition_label.lower()],
        ALL_FEATURES,
        condition_name=condition_label,
        control_group=control_group,
    )
    print_results_summary(condition_results)
    print_dunn_results(condition_results, control_group=control_group)
    return condition_results


# --- Analyze Drug Condition ---
drug_results = _analyze_condition("Drug")

# --- Analyze Washout Condition ---
washout_results = _analyze_condition("Washout")

print("\n" + "=" * 60)
print("ANALYSIS COMPLETE")
print("=" * 60)
