In [None]:
import os
import sys

# On Windows, force single-threaded OpenMP for the computations in this cell.
# This is placed before the numpy / scikit-learn imports below (presumably so
# the OpenMP runtime sees the value when it loads). The previous value is kept
# in `_original_omp_num_threads_` (None when unset or when not on Windows) so
# it can be restored at the end of the cell.
_original_omp_num_threads_ = (
    os.environ.get('OMP_NUM_THREADS') if sys.platform == 'win32' else None
)
if sys.platform == 'win32':
    os.environ['OMP_NUM_THREADS'] = '1'

import numpy as np
import matplotlib.pyplot as plt
import re
from collections import defaultdict
from scipy.spatial.distance import euclidean
import pandas as pd
from matplotlib.lines import Line2D
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # For LDA

def ep_read(filename):
    """
    Load a Cartool .ep text file into a 2-D array.

    Every non-blank line is split on whitespace and its fields are parsed as
    floats; blank (or whitespace-only) lines are ignored.

    Parameters
    ----------
    filename : str
        Path of the .ep file.

    Returns
    -------
    numpy.ndarray
        One row per non-blank line of the file, one column per number on that
        line (built with ``np.array`` from the parsed rows).
    """
    rows = []
    with open(filename, 'r') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            rows.append([float(token) for token in raw_line.split()])
    return np.array(rows)

def read_cartool_xyz(xyz_file):
    """
    Read electrode coordinates and labels from a Cartool .xyz file.

    The first line is a header whose first field is the number of electrodes
    N. Each of the next N lines holds ``x y z name``. The Y coordinate is
    sign-flipped on reading (x and z are kept unchanged); any lines after the
    first N electrode lines are ignored.

    Parameters
    ----------
    xyz_file : str
        Path to the .xyz file.

    Returns
    -------
    positions : numpy.ndarray, shape (N, 3)
        Electrode coordinates as (x, -y, z); row i belongs to electrode i.
    ch_names : list of str
        Electrode names, in the same order as the rows of `positions`.

    Raises
    ------
    IndexError
        If the file has fewer than N electrode lines or an electrode line has
        fewer than four fields.
    ValueError
        If the electrode count or a coordinate is not numeric.
    """
    with open(xyz_file, 'r') as handle:
        file_lines = handle.readlines()

    n_electrodes = int(file_lines[0].split()[0])
    positions = np.zeros((n_electrodes, 3))
    ch_names = []
    for electrode_idx in range(n_electrodes):
        fields = file_lines[electrode_idx + 1].split()
        x_val = float(fields[0])
        y_val = float(fields[1])
        z_val = float(fields[2])
        # Y is stored with the opposite sign.
        positions[electrode_idx] = (x_val, -y_val, z_val)
        ch_names.append(fields[3])
    return positions, ch_names

def _load_condition_subject_data(input_dir, group, condition):
    """
    Load the per-subject median microstate maps of one group/condition.

    Files named ``Subject_<ID>_median_microstates.ep`` in `input_dir` are read
    with `ep_read` in sorted filename order, so downstream output is
    deterministic. Only files holding exactly four maps (rows) are kept;
    unreadable files are reported and skipped.

    Parameters
    ----------
    input_dir : str
        Directory containing the subject .ep files.
    group, condition : str
        Used only in progress messages.

    Returns
    -------
    dict
        ``{subject_id (str): numpy.ndarray with 4 rows}``. Empty when the
        directory is missing or holds no valid file (a message is printed).
    """
    if not os.path.exists(input_dir):
        print(f"  Directory not found: {input_dir}, skipping.")
        return {}
    subject_files = sorted(
        f for f in os.listdir(input_dir)
        if f.startswith('Subject_') and f.endswith('_median_microstates.ep')
    )
    if not subject_files:
        print(f"  No subject files found in {input_dir}, skipping.")
        return {}

    file_pattern = re.compile(r'Subject_(\d+)_median_microstates\.ep')
    subject_data = {}
    for filename in subject_files:
        match = file_pattern.search(filename)
        if not match:
            continue
        try:
            data_content = ep_read(os.path.join(input_dir, filename))
        except Exception as e:
            print(f"  Error reading {filename}: {e}")
            continue
        # Row k is used as the map of microstate class k (A-D) downstream.
        if data_content.shape[0] == 4:
            subject_data[match.group(1)] = data_content

    if not subject_data:
        print(f"  No valid subject data found for {group} - {condition}, skipping.")
    return subject_data


def _compute_distance_records(group, condition, subject_data, ms_labels_map):
    """
    Compute SCD and DCD for every (subject, microstate class) pair.

    SCD is the Euclidean distance of a subject's class-k map to the group
    median map of class k; DCD is the mean Euclidean distance of that map to
    the group median maps of the other classes. Group medians are element-wise
    medians across subjects.

    Returns
    -------
    list of dict
        One record per subject and class with keys 'Group', 'Condition',
        'Subject ID', 'Microstate', 'Same Class Distance', 'Diff Class Distance'.
    """
    group_medians_reference = {}
    for ms_val in range(4):
        class_maps = [patterns[ms_val] for patterns in subject_data.values()]
        if class_maps:
            group_medians_reference[ms_val] = np.median(np.array(class_maps), axis=0)

    condition_records = []
    for subject_id, patterns in subject_data.items():
        for ms_class in range(4):
            if ms_class not in group_medians_reference:
                continue
            own_map = patterns[ms_class]
            same_class_dist = euclidean(own_map, group_medians_reference[ms_class])
            other_dists = [euclidean(own_map, median_map)
                           for other_ms, median_map in group_medians_reference.items()
                           if other_ms != ms_class]
            avg_other_dist = np.mean(other_dists) if other_dists else np.nan
            condition_records.append({
                'Group': group, 'Condition': condition, 'Subject ID': subject_id,
                'Microstate': ms_labels_map[ms_class], 'Same Class Distance': same_class_dist,
                'Diff Class Distance': avg_other_dist
            })
    return condition_records


def _classify_outliers_lda(df_condition, group, condition):
    """
    Flag outliers in `df_condition` in place (column 'Is Outlier').

    The (SCD, DCD) features are standardized with `StandardScaler`. Points with
    SCD <= median and DCD >= median are labelled "probable non-outliers" (0),
    points with SCD > median and DCD < median "probable outliers" (1). An LDA
    model trained on those two groups then classifies every point. If only one
    group is present, the outlier quadrant itself is used instead. With fewer
    than three rows, or if LDA fails, nothing is flagged.

    Returns
    -------
    tuple
        ``(median_scd, median_dcd, lda_model, scaler_model)``. The medians may
        be NaN. `lda_model` is a fitted `LinearDiscriminantAnalysis`, or None
        when no model could be fitted. `scaler_model` is the fitted scaler, or
        None.
    """
    df_condition['Is Outlier'] = False  # Default
    median_scd = np.nan
    median_dcd = np.nan
    lda_model = None
    scaler_model = None

    if df_condition.empty:
        print(f"  DataFrame empty for {group} - {condition} after NaN drop. No outliers.")
        return median_scd, median_dcd, lda_model, scaler_model
    if len(df_condition) <= 2:  # Need enough points for medians and LDA
        print(f"  Not enough data points ({len(df_condition)}) for LDA in {group} - {condition}. No outliers.")
        return median_scd, median_dcd, lda_model, scaler_model

    scd = df_condition['Same Class Distance']
    dcd = df_condition['Diff Class Distance']
    scaler_model = StandardScaler()
    scaled_features = scaler_model.fit_transform(
        df_condition[['Same Class Distance', 'Diff Class Distance']].values)

    median_scd = scd.median()
    median_dcd = dcd.median()
    if not (pd.notna(median_scd) and pd.notna(median_dcd)):
        print(f"  Warning: Could not calculate SCD/DCD medians for {group} - {condition}. No outliers marked.")
        return median_scd, median_dcd, lda_model, scaler_model

    # Quadrants use the original (unscaled) values; LDA uses the scaled ones.
    is_pno = ((scd <= median_scd) & (dcd >= median_dcd)).to_numpy()
    is_po = ((scd > median_scd) & (dcd < median_dcd)).to_numpy()
    train_mask = is_pno | is_po
    X_lda_train = scaled_features[train_mask]
    y_lda_train = np.where(is_po, 1.0, 0.0)[train_mask]  # 1 = outlier, 0 = non-outlier

    if len(X_lda_train) > 0 and len(np.unique(y_lda_train)) == 2:
        try:
            candidate_model = LinearDiscriminantAnalysis()
            candidate_model.fit(X_lda_train, y_lda_train)
            all_predictions_lda = candidate_model.predict(scaled_features)
        except Exception as e_lda:
            # Keep lda_model None: an unfitted model has no coef_ to plot.
            print(f"  Error during LDA for {group} - {condition}: {e_lda}. No outliers assigned by LDA.")
            df_condition['Is Outlier'] = False
        else:
            lda_model = candidate_model
            df_condition['Is Outlier'] = (all_predictions_lda == 1)
            num_final_outliers = df_condition['Is Outlier'].sum()
            print(f"  LDA trained on groups. Medians (SCD:{median_scd:.2f}, DCD:{median_dcd:.2f}). Found {num_final_outliers} final outliers.")
    else:
        print(f"  Warning: Not enough data or only one class for LDA training in {group} - {condition} (PNO: {is_pno.sum()}, PO: {is_po.sum()}). Using simple median quadrant as fallback.")
        df_condition['Is Outlier'] = is_po
        num_fallback_outliers = df_condition['Is Outlier'].sum()
        print(f"    Fallback to median quadrant: found {num_fallback_outliers} outliers.")
    return median_scd, median_dcd, lda_model, scaler_model


def _plot_condition_outliers(df_condition, group, condition, plot_dir,
                             median_scd_for_plot, median_dcd_for_plot,
                             lda_model, scaler_model):
    """
    Draw and save the SCD-vs-DCD scatter plot of one group/condition.

    Non-outliers are drawn as circles and outliers as red-edged crosses, both
    colored by microstate class. Each outlier gets an "S<ID>(<class>)" label
    placed at a fixed height (y=0.8 in data units) with an arrow to the point.
    The LDA decision boundary (mapped back from standardized to original units)
    and the SCD/DCD median lines are overlaid when available. The figure is
    saved as ``<group>_<condition>_ALL_MS_lda_median_outlier_plot.png`` in
    `plot_dir` and is always closed, even if drawing fails.
    """
    import warnings

    ms_colors_map = {'A': 'blue', 'B': 'green', 'C': 'orange', 'D': 'purple'}
    scd = df_condition['Same Class Distance']
    dcd = df_condition['Diff Class Distance']

    fig = plt.figure(figsize=(12, 12))
    try:
        # Axes start at 0; y initially gets twice the x range (+10 % padding)
        # before the equal-aspect adjustment.
        padding_factor = 1.1
        finite_scd = scd[np.isfinite(scd)]
        finite_dcd = dcd[np.isfinite(dcd)]
        x_data_max_val = finite_scd.max() if not finite_scd.empty else 0.0
        y_data_max_val = finite_dcd.max() if not finite_dcd.empty else 0.0
        base_max_for_scaling = max(x_data_max_val, y_data_max_val, 0.1)
        plt.xlim(0, base_max_for_scaling * padding_factor)
        plt.ylim(0, base_max_for_scaling * 2.0 * padding_factor)
        plt.axis('equal')
        final_xlims = plt.xlim()

        # Fixed height (data units) at which outlier labels are placed.
        text_baseline_y_for_anno = 0.8

        ms_handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=col, markersize=8, label=f'MS {ms_lab}')
                      for ms_lab, col in ms_colors_map.items()]
        boundary_handles = []
        status_handles = [
            Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=8, label='Non-outlier'),
            Line2D([0], [0], marker='X', color='w', markerfacecolor='gray', markersize=8, markeredgecolor='red', label='Outlier (LDA-defined)'),
        ]

        # LDA boundary: w0*x_s + w1*y_s + b = 0 in standardized space,
        # converted back to the original SCD/DCD units for plotting.
        if lda_model is not None and scaler_model is not None:
            w = lda_model.coef_[0]
            b = lda_model.intercept_[0]
            if abs(w[1]) > 1e-6:  # Non-vertical line: solve for y
                xx_line = np.linspace(final_xlims[0], final_xlims[1], 100)
                xx_line_scaled_x = (xx_line - scaler_model.mean_[0]) / scaler_model.scale_[0]
                yy_line_scaled_y = (-w[0] * xx_line_scaled_x - b) / w[1]
                yy_line = yy_line_scaled_y * scaler_model.scale_[1] + scaler_model.mean_[1]
                plt.plot(xx_line, yy_line, 'k-', linewidth=2, label='LDA Decision Boundary')
                boundary_handles.append(Line2D([0], [0], linestyle='-', color='k', label='LDA Boundary'))
            elif abs(w[0]) > 1e-6:  # Vertical line: x_scaled = -b / w0
                x_boundary_val = (-b / w[0]) * scaler_model.scale_[0] + scaler_model.mean_[0]
                plt.axvline(x_boundary_val, color='k', linestyle='-', linewidth=2, label='LDA Decision Boundary (Vertical)')
                boundary_handles.append(Line2D([0], [0], linestyle='-', color='k', label='LDA Boundary'))

        # Median reference lines, if the medians could be computed.
        if pd.notna(median_scd_for_plot) and pd.notna(median_dcd_for_plot):
            plt.axvline(median_scd_for_plot, color='grey', linestyle=':', linewidth=1, alpha=0.7, label=f'_MedSCD ({median_scd_for_plot:.2f})')
            plt.axhline(median_dcd_for_plot, color='grey', linestyle=':', linewidth=1, alpha=0.7, label=f'_MedDCD ({median_dcd_for_plot:.2f})')
            boundary_handles.append(Line2D([0], [0], linestyle=':', color='grey', label='Median Info'))

        is_outlier = df_condition['Is Outlier'].astype(bool)
        non_outlier_df = df_condition[~is_outlier]
        outlier_df = df_condition[is_outlier]

        for ms_label_plot, color in ms_colors_map.items():
            plot_data = non_outlier_df[non_outlier_df['Microstate'] == ms_label_plot]
            if not plot_data.empty:
                plt.scatter(plot_data['Same Class Distance'], plot_data['Diff Class Distance'],
                            c=color, marker='o', s=60, alpha=0.7, zorder=2)

        for ms_label_plot, color in ms_colors_map.items():
            plot_data = outlier_df[outlier_df['Microstate'] == ms_label_plot]
            if plot_data.empty:
                continue
            plt.scatter(plot_data['Same Class Distance'], plot_data['Diff Class Distance'],
                        facecolors=color, edgecolors='red', marker='X', s=100, linewidths=1.5, zorder=3)
            for _, row in plot_data.iterrows():
                x_coord, y_coord = row['Same Class Distance'], row['Diff Class Distance']
                if not (np.isfinite(x_coord) and np.isfinite(y_coord)):
                    continue
                plt.annotate(
                    f"S{row['Subject ID']}({ms_label_plot})", xy=(x_coord, y_coord),
                    xytext=(x_coord, text_baseline_y_for_anno), fontsize=14, color='dimgray', zorder=5,
                    ha='center', va='bottom',
                    arrowprops=dict(arrowstyle="simple,head_length=0.3,head_width=0.3,tail_width=0.1",
                                    linewidth=0.8, color='darkgray', shrinkA=2, shrinkB=2,
                                    connectionstyle="arc3,rad=0"),
                    bbox=dict(facecolor='white', alpha=0.6, edgecolor='lightgray', pad=0.2)
                )

        plt.xlabel('Distance to Same Class Median (SCD)')
        plt.ylabel('Average Distance to Different Class Medians (DCD)')
        plt.title(f'{group} - {condition}\nOutliers defined by LDA')
        plt.grid(True, alpha=0.3)
        plt.legend(handles=ms_handles + boundary_handles + status_handles, loc='best', fontsize=12)

        # tight_layout reports problems as warnings (it does not raise), so
        # record them instead of catching an exception.
        with warnings.catch_warnings(record=True) as layout_warnings:
            warnings.simplefilter('always')
            plt.tight_layout(pad=0.3)
        if any(issubclass(w_item.category, UserWarning) for w_item in layout_warnings):
            print(f"  Note: Tight layout could not be fully applied for plot {group} - {condition}.")

        plt.savefig(os.path.join(plot_dir, f'{group}_{condition}_ALL_MS_lda_median_outlier_plot.png'), dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)


def _write_outlier_summary(all_data_list, all_outliers_dict, groups, conditions,
                           ms_labels_map, results_dir):
    """
    Save all distance records and a per-microstate outlier summary as CSV.

    Writes ``all_microstate_distance_data_lda_median.csv`` (one row per record)
    and ``outlier_summary_lda.csv`` (total, outlier and non-outlier counts and
    outlier percentage per group/condition/microstate) into `results_dir`. Only
    group/condition pairs present in `all_outliers_dict` are summarized.
    """
    all_df = pd.DataFrame(all_data_list)
    csv_filename = 'all_microstate_distance_data_lda_median.csv'
    all_df.to_csv(os.path.join(results_dir, csv_filename), index=False)

    summary_data = []
    for group_s in groups:
        processed_conditions = all_outliers_dict.get(group_s, {})
        for condition_s in conditions:
            if condition_s not in processed_conditions:
                continue
            for ms_label_s in ms_labels_map.values():
                subset = all_df[(all_df['Group'] == group_s)
                                & (all_df['Condition'] == condition_s)
                                & (all_df['Microstate'] == ms_label_s)]
                total = len(subset)
                outlier_count = int(subset['Is Outlier'].astype(bool).sum()) if total else 0
                summary_data.append({
                    'Group': group_s, 'Condition': condition_s, 'Microstate': ms_label_s,
                    'Total Subjects': total, 'Outliers': outlier_count,
                    'Non-outliers': total - outlier_count,
                    'Outlier Percentage': round(outlier_count / total * 100, 1) if total > 0 else 0
                })

    if summary_data:
        summary_df = pd.DataFrame(summary_data)
        summary_csv_filename = 'outlier_summary_lda.csv'
        summary_path = os.path.join(results_dir, summary_csv_filename)
        summary_df.to_csv(summary_path, index=False)
        print(f"Saved LDA-based summary statistics to {summary_path}")


def plot_outliers(output_base_dir):
    """
    Detect microstate topographic outliers with an LDA model and plot them.

    For each group ('HC', 'Young', 'ADNoEp', 'ADEp') and condition ('Awake',
    'N2', 'REM'), per-subject median microstate maps are read from
    ``<output_base_dir>/<group>_<condition>_median``. For every subject and
    microstate class, two distances are computed:

    * the same class distance (SCD) to the group median map of that class;
    * the mean different class distance (DCD) to the other classes' group
      median maps.

    A Linear Discriminant Analysis (LDA) model is trained to separate
    probable outliers (high SCD, low DCD relative to the medians) from probable
    non-outliers (low SCD, high DCD), and is then applied to all points.

    One scatter plot per group/condition plus two CSV files are written under
    ``<output_base_dir>/outlier_plots_lda``. Note: the global matplotlib
    ``rcParams`` font sizes are updated as a side effect.

    Parameters
    ----------
    output_base_dir : str
        Directory containing the ``<group>_<condition>_median`` folders.

    Returns
    -------
    all_outliers_dict : dict
        ``{group: {condition: {'A': [subject_id, ...], 'B': ..., ...}}}`` with
        subject IDs as strings. Skipped conditions are absent.
    all_non_outliers_dict : dict
        Same structure, listing subjects not flagged as outliers.
    """
    plt.rcParams.update({
        'font.size': 18, 'axes.titlesize': 18, 'axes.labelsize': 18,
        'xtick.labelsize': 18, 'ytick.labelsize': 18, 'legend.fontsize': 16
    })

    ms_labels_map = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}
    groups = ['HC', 'Young', 'ADNoEp', 'ADEp']
    conditions = ['Awake', 'N2', 'REM']

    results_dir = os.path.join(output_base_dir, 'outlier_plots_lda')
    os.makedirs(results_dir, exist_ok=True)

    all_outliers_dict = {}
    all_non_outliers_dict = {}
    all_data_list = []

    for group in groups:
        all_outliers_dict[group] = {}
        all_non_outliers_dict[group] = {}
        plot_dir = os.path.join(results_dir, group)
        for condition in conditions:
            print(f"Processing {group} - {condition}...")
            input_dir = os.path.join(output_base_dir, f'{group}_{condition}_median')

            subject_data = _load_condition_subject_data(input_dir, group, condition)
            if not subject_data:
                continue

            condition_records = _compute_distance_records(group, condition, subject_data, ms_labels_map)
            if not condition_records:
                print(f"  No records to process for {group} - {condition}.")
                continue

            df_condition = pd.DataFrame(condition_records)
            df_condition.dropna(subset=['Same Class Distance', 'Diff Class Distance'], inplace=True)
            median_scd, median_dcd, lda_model, scaler_model = _classify_outliers_lda(
                df_condition, group, condition)

            all_data_list.extend(df_condition.to_dict('records'))
            condition_outliers = {label: [] for label in ms_labels_map.values()}
            condition_non_outliers = {label: [] for label in ms_labels_map.values()}
            for ms_label, subject_id, flagged in zip(df_condition['Microstate'],
                                                     df_condition['Subject ID'],
                                                     df_condition['Is Outlier']):
                target = condition_outliers if flagged else condition_non_outliers
                target[ms_label].append(str(subject_id))
            all_outliers_dict[group][condition] = condition_outliers
            all_non_outliers_dict[group][condition] = condition_non_outliers

            if df_condition.empty:
                continue

            # Create the group folder only when there is something to plot.
            os.makedirs(plot_dir, exist_ok=True)
            _plot_condition_outliers(df_condition, group, condition, plot_dir,
                                     median_scd, median_dcd, lda_model, scaler_model)
            print(f"  Created LDA-based outlier plot for {group} - {condition}")

    if all_data_list:
        _write_outlier_summary(all_data_list, all_outliers_dict, groups, conditions,
                               ms_labels_map, results_dir)

    print("2D microstate outlier assessment with LDA complete.")
    return all_outliers_dict, all_non_outliers_dict

# Run the LDA-based outlier detection on every subject's median microstate maps
# and list the outliers that were found. The Windows OMP_NUM_THREADS setting
# made at the top of this cell is restored in `finally`, so it is reset even
# when the run stops early (missing input directory or an exception).
output_base_dir = './median_microstates_and_outliers'
try:
    if not os.path.exists(output_base_dir):
        print(f"Error: Base directory for input data '{output_base_dir}' not found.")
        sys.exit(1)

    outliers, non_outliers = plot_outliers(output_base_dir)

    no_outliers_found_overall = True
    if outliers:
        for group_name, conditions_data in outliers.items():
            for condition_name, ms_data in conditions_data.items():
                for ms_label_name, subject_ids in ms_data.items():
                    if subject_ids:
                        print(f"Group: {group_name}, Condition: {condition_name}, MS: {ms_label_name}, Outliers: {subject_ids}")
                        no_outliers_found_overall = False
    if no_outliers_found_overall:
        print(f"No outliers identified with the LDA-based criteria and current data.")
finally:
    # Restore calculation parameters if you are using Windows
    if sys.platform == 'win32':
        if _original_omp_num_threads_ is None:
            os.environ.pop('OMP_NUM_THREADS', None)
        else:
            os.environ['OMP_NUM_THREADS'] = _original_omp_num_threads_

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests
import matplotlib.colors as mcolors
from tqdm import tqdm
import warnings
import math
import re

# --- GLOBAL CONSTANTS ---
# Microstate class index -> letter label (0 -> 'A', ..., 3 -> 'D').
ms_labels = dict(enumerate('ABCD'))

def cohen_d(x, y):
    """
    Calculate Cohen's d effect size for two independent samples.

    ``d = (mean(x) - mean(y)) / pooled_sd``. The pooled standard deviation
    combines the sample standard deviations (ddof=1), weighted by their degrees
    of freedom.

    Parameters
    ----------
    x, y : array-like of float
        The two samples.

    Returns
    -------
    float
        Cohen's d, with the sign of mean(x) - mean(y). Returns ``nan`` if
        either sample is empty. When there is no within-group spread (both
        samples have at most one element, or both standard deviations are 0),
        returns:

        * ``0.0`` if the means are equal;
        * a signed infinity (``+inf`` if mean(x) > mean(y), ``-inf`` if
          mean(x) < mean(y)) otherwise;
        * ``nan`` if the mean difference is NaN.
    """
    nx = len(x)
    ny = len(y)
    if nx < 1 or ny < 1:
        return np.nan
    mx = np.mean(x)
    my = np.mean(y)
    sx = np.std(x, ddof=1) if nx > 1 else 0
    sy = np.std(y, ddof=1) if ny > 1 else 0

    def _zero_spread_effect():
        # Keep the direction of the effect instead of always returning +inf.
        if mx == my:
            return 0.0
        diff = mx - my
        if np.isnan(diff):
            return np.nan
        return float(np.copysign(np.inf, diff))

    if (nx <= 1 and ny <= 1) or (sx == 0 and sy == 0):
        return _zero_spread_effect()

    pooled_sd = np.sqrt(((nx - 1) * sx**2 + (ny - 1) * sy**2) / (nx + ny - 2))

    return (mx - my) / pooled_sd if pooled_sd != 0 else _zero_spread_effect()


def draw_head(ax):
    """
    Draw a schematic head outline on `ax`.

    Draws a unit circle centred at the origin, a short vertical nose tick at
    the top and two short ear strokes at the sides, all in black with
    clipping disabled. The axes then get an equal aspect ratio, limits of
    +/-1.3 on both axes, and the axis decorations are hidden.
    """
    radius = 1.0
    line_style = dict(color='k', clip_on=False, zorder=1)

    ax.add_patch(plt.Circle((0, 0), radius=radius, fill=False, lw=1.5, **line_style))

    # Nose: vertical tick from the top of the circle.
    ax.plot([0, 0], [radius, radius * 1.1], lw=1.5, **line_style)

    # Ears: left (side=-1) then right (side=+1), mirrored about x = 0.
    ear_base_x = radius * 0.98
    ear_base_y = 0.0
    ear_dx = 0.12 * radius
    ear_dy = 0.15 * radius
    for side in (-1, 1):
        ax.plot([side * ear_base_x, side * (ear_base_x + ear_dx)],
                [ear_base_y, ear_base_y + ear_dy],
                lw=1.0, **line_style)

    limit = radius * 1.3
    ax.set_aspect('equal')
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.axis('off')

def plot_channel_values(
    ax, p2d_standard_coords, values, cmap, vmin, vmax, default_size=150,
    sig_mask=None, size_values=None, size_vmin=0.0, size_vmax=1.0,
    min_marker_size=50, max_marker_size=600):
    """
    Plot channels as colored circles on `ax`, with optional size mapping.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    p2d_standard_coords : numpy.ndarray, shape (n_channels, 2)
        2-D channel positions. Anything else makes the function return None.
    values : array-like, length n_channels
        Per-channel values mapped to colors with `cmap` over [vmin, vmax].
        NaN channels are not drawn.
    cmap, vmin, vmax
        Colormap and color limits.
    default_size : float, default 150
        Marker size used when no valid `size_values` are given.
    sig_mask : array-like of bool, optional
        Only channels where the mask is True are drawn. Ignored (all channels
        drawn) if None or of a different length than `values`.
    size_values : array-like, optional
        Per-channel values mapped linearly from [size_vmin, size_vmax] to
        [min_marker_size, max_marker_size]. Values are clipped to that range;
        NaN is treated as `size_vmin`. Ignored if None or of the wrong length.
    size_vmin, size_vmax : float
        Value range for the size mapping. A range <= 1e-9 is replaced by 1.
    min_marker_size, max_marker_size : float
        Marker-size range (never below 1).

    Returns
    -------
    matplotlib PathCollection or None
        The scatter artist. Returns None when nothing can be drawn: invalid
        inputs, no valid channels, or a scatter error (the error is printed).
    """
    if (p2d_standard_coords is None or not isinstance(p2d_standard_coords, np.ndarray)
            or p2d_standard_coords.ndim != 2 or p2d_standard_coords.shape[1] != 2):
        return None
    if values is None or len(values) != p2d_standard_coords.shape[0]:
        return None
    # Accept lists / Series too: the fancy indexing below needs ndarrays.
    values = np.asarray(values, dtype=float)

    if sig_mask is None or len(sig_mask) != len(values):
        sig_mask = np.ones(len(values), dtype=bool)
    else:
        sig_mask = np.asarray(sig_mask, dtype=bool)

    valid_idx = np.flatnonzero(~np.isnan(values) & sig_mask)
    if valid_idx.size == 0:
        return None

    plot_x = p2d_standard_coords[valid_idx, 0]
    plot_y = p2d_standard_coords[valid_idx, 1]

    scatter_sizes_arr = np.full(valid_idx.size, default_size, dtype=float)
    if size_values is not None and len(size_values) == len(values):
        valid_size_vals = np.asarray(size_values, dtype=float)[valid_idx]
        size_range_val = size_vmax - size_vmin
        if size_range_val <= 1e-9:
            size_range_val = 1.0
        normalized_s_vals = np.clip(
            (np.nan_to_num(valid_size_vals, nan=size_vmin) - size_vmin) / size_range_val, 0, 1)
        scatter_sizes_arr = np.maximum(
            min_marker_size + (max_marker_size - min_marker_size) * normalized_s_vals, 1)

    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    try:
        return ax.scatter(plot_x, plot_y, c=values[valid_idx], cmap=cmap, norm=norm,
                          edgecolors='k', linewidths=0.5, s=scatter_sizes_arr, marker='o', zorder=3)
    except Exception as e_scatter:
        # Best-effort plotting: report and let the caller continue.
        print(f"Error during scatter plot: {e_scatter}")
        return None

def compare_median_microstates_across_conditions(
    ref_path,
    other_paths,
    output_base,
    xyz_file=None,
    exclude_outliers=False,
    outlier_subjects=None, 
    output_suffix=''
    ):
    """
    Compares median microstate maps across different conditions.

    This function loads median microstate maps for a reference condition and
    one or more other conditions, performs channel-wise statistical comparisons
    (independent samples t-tests), applies FDR correction for multiple
    comparisons, and generates visualizations and summary reports.

    Parameters
    ----------
    ref_path : str
        Path to the directory containing median microstate .ep files for the
        reference condition. Files should be named like 'Subject_<ID>_median_microstates.ep'.
    other_paths : dict
        A dictionary where keys are condition names (e.g., 'ConditionA') and
        values are paths to directories containing median microstate .ep files
        for those conditions. File naming convention should be the same as for
        `ref_path`.
    output_base : str
        Base path for the output directory where results (plots, CSVs, reports)
        will be saved. Subdirectories for each comparison pair will be created
        under this base path.
    xyz_file : str, optional
        Path to a Cartool .xyz electrode coordinates file. If provided and valid,
        electrode positions will be used for plotting. If None or the file is
        not found/invalid, plotting will be disabled.
    exclude_outliers : bool, default False
        If True, subjects identified as outliers (based on `outlier_subjects`)
        will be excluded from the statistical comparisons and group averages
        on a microstate-specific basis.
    outlier_subjects : dict, optional
        A nested dictionary specifying outlier subjects for each condition and
        microstate. The structure should be:
        `{'Group1': {'ConditionA': {'MS_A': ['subj1', 'subj2'], 'MS_B': [...]}, ...}, ...}`
        where 'Group1' might be parsed from the condition key, and 'MS_A', 'MS_B'
        are microstate labels (e.g., 'A', 'B', 'C', 'D'). This is only used if
        `exclude_outliers` is True.
    output_suffix : str, default ''
        A suffix to add to the output directory names, filenames (plots, CSVs,
        reports).

    Returns
    -------
    dict
        A nested dictionary containing the results of the statistical comparisons
        for each pair of conditions and each microstate. The structure is:
        `{(condition1, condition2): {ms_index: {'raw_p_values': ..., 'z_mean_diffs': ..., ...}, ...}, ...}`
        where `ms_index` is 0-3 representing microstates A, B, C, D respectively.
        Returns None if no data is successfully loaded or if there are not enough
        conditions for comparison.
    """
    # Set a fixed random seed for reproducibility
    np.random.seed(42)

    # Prepare output directory path based on whether outliers are excluded
    if exclude_outliers:
        actual_suffix = output_suffix if output_suffix else "_ms_excluded"
        output_base_analysis = f"{output_base}{actual_suffix}"
        print(f"\nINFO: Running analysis with MICROSTATE-WISE outlier exclusion. Results: {output_base_analysis}")
    else:
        output_base_analysis = output_base
        print(f"\nINFO: Running analysis with all subjects/microstates. Results: {output_base_analysis}")
    os.makedirs(output_base_analysis, exist_ok=True)

    # Load electrode coordinates for topographic plotting if available
    pos_xyz_coords = None; ch_names_from_xyz_file = None; p2d_normalized_coords = None
    if xyz_file and os.path.exists(xyz_file):
        try:
            # Attempt to read electrode positions from XYZ file
            pos_xyz_coords, ch_names_from_xyz_file = read_cartool_xyz(xyz_file)
            print(f"  Loaded {len(ch_names_from_xyz_file)} electrode positions from {xyz_file}.")
            
            # Validate the electrode data format
            if not (isinstance(pos_xyz_coords, np.ndarray) and pos_xyz_coords.ndim == 2 and pos_xyz_coords.shape[1] == 3 and
                    isinstance(ch_names_from_xyz_file, list) and len(ch_names_from_xyz_file) == pos_xyz_coords.shape[0]):
                raise ValueError("XYZ file data format incorrect.")
        except Exception as e:
            # Disable plotting if electrode file loading fails
            print(f"  Error loading/validating electrode file '{xyz_file}': {e}. Plotting disabled.")
            pos_xyz_coords = None
    else:
        print(f"  Electrode file ({xyz_file}) not provided or not found. Plotting disabled.")

    # Extract condition name from reference path and create a dictionary of all conditions
    try: 
        ref_name_key_main = os.path.basename(ref_path.rstrip(os.sep)).replace('_median', '')
    except Exception: 
        ref_name_key_main = "Reference"
    
    # Combine reference and other conditions into a single dictionary for processing
    conditions_to_load_paths_main = {ref_name_key_main: ref_path}
    conditions_to_load_paths_main.update(other_paths)
    print(f"\nConditions to be processed: {list(conditions_to_load_paths_main.keys())}")

    # Initialize storage for condition data and subject counts
    condition_data_store = {}
    n_channels_global_val = None
    initial_subject_counts_dict = {}

    # Define known group and condition names for parsing condition keys
    known_groups_for_parse = ['HC', 'Young', 'ADNoEp', 'ADEp']
    known_conditions_short_parse = ['Awake', 'N2', 'REM']

    # Load data for each condition
    for cond_key_main, data_folder_path_main in conditions_to_load_paths_main.items():
        print(f"\nLoading data for condition key: {cond_key_main} from path: {data_folder_path_main}")
        initial_subject_counts_dict[cond_key_main] = {'found_files': 0, 'loaded_subjects': 0}
        
        # Attempt to list files in the condition directory
        try: 
            files_in_folder_main = os.listdir(data_folder_path_main)
        except FileNotFoundError: 
            print(f"  Error: Directory not found: {data_folder_path_main}. Skipping.")
            continue
        
        # Filter for subject median microstate files
        subject_ep_files_main = sorted([f for f in files_in_folder_main if f.startswith('Subject_') and f.endswith('_median_microstates.ep')])
        initial_subject_counts_dict[cond_key_main]['found_files'] = len(subject_ep_files_main)
        
        if not subject_ep_files_main: 
            print(f"  No subject median files found for {cond_key_main}. Skipping.")
            continue

        # Storage for subject data in this condition
        current_cond_subj_data_main = []
        
        # Try to parse group and condition from condition key for outlier handling
        parsed_g_main, parsed_c_main = None, None
        key_parts_main = cond_key_main.split('_')
        if len(key_parts_main) >= 2:
            potential_g_main = key_parts_main[0]
            potential_c_short_main = key_parts_main[-1]
            if potential_g_main in known_groups_for_parse and potential_c_short_main in known_conditions_short_parse:
                parsed_g_main, parsed_c_main = potential_g_main, potential_c_short_main
            elif len(key_parts_main) > 1:
                joined_group_candidate_main = "_".join(key_parts_main[:-1])
                if joined_group_candidate_main in known_groups_for_parse and potential_c_short_main in known_conditions_short_parse:
                    parsed_g_main, parsed_c_main = joined_group_candidate_main, potential_c_short_main
        
        if not (parsed_g_main and parsed_c_main):
            print(f"  Warning: Could not parse Group/Condition from key '{cond_key_main}'. Outlier exclusion depends on matching keys in 'outlier_subjects'.")

        # Process each subject's median microstate file
        for ep_filename_main in tqdm(subject_ep_files_main, desc=f"  Loading {cond_key_main}", unit="file", leave=False):
            full_fpath_main = os.path.join(data_folder_path_main, ep_filename_main)
            
            # Extract subject ID from filename
            id_match_main = re.search(r'Subject_(\d+)_median_microstates\.ep', ep_filename_main)
            if not id_match_main: 
                continue
                
            subj_id_main = id_match_main.group(1)
            
            try:
                # Read the microstate map from .ep file
                subj_ms_map_main = ep_read(full_fpath_main)
                
                # Validate microstate map format (should be 4 microstates)
                if not (isinstance(subj_ms_map_main, np.ndarray) and subj_ms_map_main.ndim == 2 and subj_ms_map_main.shape[0] == 4):
                    print(f"  Warn: Invalid map format or not 4 MS for {subj_id_main} in {ep_filename_main}. Skip.")
                    continue
                    
                # Check channel count consistency
                current_n_ch_main = subj_ms_map_main.shape[1]
                if n_channels_global_val is None: 
                    n_channels_global_val = current_n_ch_main
                elif current_n_ch_main != n_channels_global_val: 
                    print(f"  Warn: Chan count mismatch for {subj_id_main} in {ep_filename_main}. Skip.")
                    continue
                
                # Calculate z-score maps for each microstate
                subj_zmap_main = np.zeros_like(subj_ms_map_main)
                for i_ms_main in range(4):
                    row_data_main = subj_ms_map_main[i_ms_main,:]
                    mean_val_main = np.nanmean(row_data_main)
                    std_val_main = np.nanstd(row_data_main)
                    # Avoid division by zero when calculating z-scores
                    subj_zmap_main[i_ms_main,:] = (row_data_main - mean_val_main) / std_val_main if std_val_main > 1e-10 else np.zeros_like(row_data_main)
                
                # Store the subject's data
                current_cond_subj_data_main.append({'id': subj_id_main, 'map': subj_ms_map_main, 'zmap': subj_zmap_main})
            
            except Exception as e_load_main: 
                print(f"  Error processing file {ep_filename_main} for subject {subj_id_main}: {e_load_main}. Skip.")

        # Check if any valid data was loaded for this condition
        if not current_cond_subj_data_main: 
            print(f"  No valid subject data loaded for {cond_key_main}. Skipping.")
            continue
            
        # Update subject count and store condition data
        initial_subject_counts_dict[cond_key_main]['loaded_subjects'] = len(current_cond_subj_data_main)
        print(f"  Loaded data for {len(current_cond_subj_data_main)} subjects for {cond_key_main}.")
        condition_data_store[cond_key_main] = {
            'subject_data': current_cond_subj_data_main,
            'parsed_group': parsed_g_main, 'parsed_condition': parsed_c_main
        }

    # Check if any data was successfully loaded
    if n_channels_global_val is None: 
        print("\nFATAL Error: No data successfully loaded, channel count undetermined.")
        return None
        
    # Set channel count and default channel names
    n_channels_final = n_channels_global_val
    master_channel_names = [f'Ch{i+1}' for i in range(n_channels_final)]

    # Process electrode coordinates for plotting if available
    if pos_xyz_coords is not None:
        # Check if electrode count matches data channel count
        if len(ch_names_from_xyz_file) != n_channels_final:
            print(f"  Warning: Electrode file channel names ({len(ch_names_from_xyz_file)}) count mismatch with data channels ({n_channels_final}). Plotting disabled.")
            p2d_normalized_coords = None
        else:
            # Use electrode names from file instead of default channel names
            master_channel_names = ch_names_from_xyz_file
            try:
                # Convert 3D coordinates to 2D for plotting (using x,y coordinates)
                coords_for_plot_2d_final = np.stack((pos_xyz_coords[:,1], pos_xyz_coords[:,0]), axis=-1)
                
                # Normalize coordinates to fit in plotting area
                max_abs_coord_val = np.max(np.abs(coords_for_plot_2d_final))
                p2d_normalized_coords = coords_for_plot_2d_final / max_abs_coord_val if max_abs_coord_val > 1e-9 else coords_for_plot_2d_final
                print("  Electrode positions processed and normalized for plotting.")
            except Exception as e_pos_final_proc:
                print(f"  Error processing electrode positions for 2D norm: {e_pos_final_proc}. Plotting disabled.")
                p2d_normalized_coords = None

    # Calculate group average maps for each condition, with outlier exclusion if specified
    condition_processed_avg_maps_dict = {}
    print("\nCalculating group averages (MS-wise exclusion if active)...")
    for cond_key_avg_main, data_payload_main in condition_data_store.items():
        # Get parsed group and condition for this condition
        pg_avg = data_payload_main['parsed_group']
        pc_avg = data_payload_main['parsed_condition']
        
        # Initialize empty average maps
        avg_raw_map = np.full((4, n_channels_final), np.nan)
        avg_z_map = np.full((4, n_channels_final), np.nan)
        n_raw_ms = [0]*4
        n_z_ms = [0]*4
        
        # Calculate average for each microstate
        for ms_i_avg in range(4):
            ms_l_avg = ms_labels[ms_i_avg]  # Convert microstate index to label (A, B, C, D)
            valid_r_maps = []
            valid_z_maps = []
            
            # Process each subject, checking for outliers
            for s_dat_avg in data_payload_main['subject_data']:
                is_o_avg = False
                
                # Check if subject is an outlier for this microstate
                if exclude_outliers and outlier_subjects and pg_avg and pc_avg and \
                   pg_avg in outlier_subjects and pc_avg in outlier_subjects[pg_avg] and \
                   ms_l_avg in outlier_subjects[pg_avg][pc_avg]:
                    if s_dat_avg['id'] in outlier_subjects[pg_avg][pc_avg][ms_l_avg]: 
                        is_o_avg = True
                
                # If not an outlier, include in average calculation
                if not is_o_avg:
                    valid_r_maps.append(s_dat_avg['map'][ms_i_avg, :])
                    valid_z_maps.append(s_dat_avg['zmap'][ms_i_avg, :])
            
            # Calculate averages if there are valid maps
            if valid_r_maps: 
                avg_raw_map[ms_i_avg,:] = np.mean(valid_r_maps, axis=0)
                n_raw_ms[ms_i_avg] = len(valid_r_maps)
            if valid_z_maps: 
                avg_z_map[ms_i_avg,:] = np.mean(valid_z_maps, axis=0)
                n_z_ms[ms_i_avg] = len(valid_z_maps)
        
        # Store average maps and subject counts
        condition_processed_avg_maps_dict[cond_key_avg_main] = {
            'raw_average': avg_raw_map, 'z_average': avg_z_map,
            'n_per_ms_raw': n_raw_ms, 'n_per_ms_z': n_z_ms
        }

    # Check that reference condition data exists and there are enough conditions for comparison
    if ref_name_key_main not in condition_data_store or len(condition_data_store) < 2:
        print(f"\nError: Ref condition '{ref_name_key_main}' data missing or not enough conditions. Cannot proceed.")
        return None

    # Create comparison pairs: 
    # 1. Reference vs each other condition
    # 2. All pairwise comparisons between other conditions
    comparison_pairs_final_list = []
    available_conds_comp_final = list(condition_data_store.keys())
    other_conds_comp_final = [c for c in available_conds_comp_final if c != ref_name_key_main]
    
    # Add reference vs each other condition
    for oc_final in other_conds_comp_final: 
        comparison_pairs_final_list.append((ref_name_key_main, oc_final))
    
    # Add pairwise comparisons between other conditions
    for i_oc1_final in range(len(other_conds_comp_final)):
        for i_oc2_final in range(i_oc1_final + 1, len(other_conds_comp_final)):
            comparison_pairs_final_list.append((other_conds_comp_final[i_oc1_final], other_conds_comp_final[i_oc2_final]))
    
    if not comparison_pairs_final_list: 
        print("\nError: No valid comparison pairs formed.")
        return None
        
    print(f"\nComparison pairs for stats: {comparison_pairs_final_list}")

    # Helper function to get microstate data for statistical testing
    def get_data_for_stat_testing_internal(target_key_stat, ms_idx_for_stat, map_type_for_stat):
        """
        Extract data for a specific microstate and condition, applying outlier exclusion if needed.
        
        Parameters:
        - target_key_stat: Condition key to retrieve data for
        - ms_idx_for_stat: Microstate index (0-3)
        - map_type_for_stat: Map type ('map' for raw or 'zmap' for z-score)
        
        Returns:
        - Array of maps (subjects x channels) for the specified microstate
        """
        ms_lab_for_stat = ms_labels[ms_idx_for_stat]
        lookup_info_stat = condition_data_store.get(target_key_stat)
        if not lookup_info_stat: 
            return np.array([])
            
        pg_for_stat, pc_for_stat = lookup_info_stat['parsed_group'], lookup_info_stat['parsed_condition']
        maps_for_stat_list = []
        
        for s_info_for_stat in lookup_info_stat['subject_data']:
            s_id_for_stat, is_o_for_stat = s_info_for_stat['id'], False
            
            # Check if subject is an outlier for this microstate
            if exclude_outliers and outlier_subjects and pg_for_stat and pc_for_stat and \
               pg_for_stat in outlier_subjects and pc_for_stat in outlier_subjects[pg_for_stat] and \
               ms_lab_for_stat in outlier_subjects[pg_for_stat][pc_for_stat]:
                if s_id_for_stat in outlier_subjects[pg_for_stat][pc_for_stat][ms_lab_for_stat]: 
                    is_o_for_stat = True
            
            # If not an outlier, include in statistical testing
            if not is_o_for_stat:
                if map_type_for_stat in s_info_for_stat and ms_idx_for_stat < s_info_for_stat[map_type_for_stat].shape[0]:
                    maps_for_stat_list.append(s_info_for_stat[map_type_for_stat][ms_idx_for_stat, :])
        
        return np.array(maps_for_stat_list)

    # Perform statistical comparisons between condition pairs
    all_comparison_results_dict = {}
    print(f"\nRunning channel-wise stats (MS-wise exclusion: {exclude_outliers})...")
    for c1_stat_main, c2_stat_main in tqdm(comparison_pairs_final_list, desc="Comparing Pairs", unit="pair"):
        all_comparison_results_dict[(c1_stat_main, c2_stat_main)] = {}
        
        # For each microstate (A, B, C, D)
        for ms_idx_stat_main in range(4):
            # Get raw and z-score data for both conditions
            m1r_stat = get_data_for_stat_testing_internal(c1_stat_main, ms_idx_stat_main, 'map')
            m2r_stat = get_data_for_stat_testing_internal(c2_stat_main, ms_idx_stat_main, 'map')
            m1z_stat = get_data_for_stat_testing_internal(c1_stat_main, ms_idx_stat_main, 'zmap')
            m2z_stat = get_data_for_stat_testing_internal(c2_stat_main, ms_idx_stat_main, 'zmap')
            
            # Store subject counts for this microstate comparison
            current_ms_res_dict = {'n_raw':(m1r_stat.shape[0],m2r_stat.shape[0]), 'n_z':(m1z_stat.shape[0],m2z_stat.shape[0])}
            
            # Process both raw and z-score data
            for d_type_stat, d1_stat_arr, d2_stat_arr, n1_stat_val, n2_stat_val, pfx_stat in [
                ('raw', m1r_stat, m2r_stat, m1r_stat.shape[0], m2r_stat.shape[0], 'raw_'),
                ('z',   m1z_stat, m2z_stat, m1z_stat.shape[0], m2z_stat.shape[0], 'z_')]:
                
                # Initialize arrays for statistical results
                p_stat_arr = np.full(n_channels_final, np.nan)
                md_stat_arr = np.full(n_channels_final, np.nan)
                t_stat_arr = np.full(n_channels_final, np.nan)
                d_stat_arr = np.full(n_channels_final, np.nan)
                
                # Perform channel-wise statistical tests if enough subjects in both groups
                if n1_stat_val >= 2 and n2_stat_val >= 2:
                    for ch_i_stat in range(n_channels_final):
                        # Get channel data for both conditions
                        d1ch_stat, d2ch_stat = d1_stat_arr[:, ch_i_stat], d2_stat_arr[:, ch_i_stat]
                        
                        # Check for valid variance in both groups
                        var1_ok_stat = np.nanvar(d1ch_stat) > 1e-10
                        var2_ok_stat = np.nanvar(d2ch_stat) > 1e-10
                        
                        if var1_ok_stat and var2_ok_stat:
                            try:
                                # Perform t-test and calculate effect size (Cohen's d)
                                t_s_val, p_s_val = ttest_ind(d1ch_stat, d2ch_stat, equal_var=False, nan_policy='omit')
                                c_d_val = cohen_d(d1ch_stat[~np.isnan(d1ch_stat)], d2ch_stat[~np.isnan(d2ch_stat)])
                                
                                # Store results
                                md_stat_arr[ch_i_stat] = np.nanmean(d1ch_stat) - np.nanmean(d2ch_stat)
                                t_stat_arr[ch_i_stat] = t_s_val
                                p_stat_arr[ch_i_stat] = p_s_val
                                d_stat_arr[ch_i_stat] = c_d_val
                            except Exception as e_stat_calc_loop: 
                                print(f" Stat Err({pfx_stat}MS{ms_idx_stat_main},Ch{ch_i_stat}):{e_stat_calc_loop}")
                        elif n1_stat_val > 0 and n2_stat_val > 0: 
                            # If variance check fails but both groups have subjects, just calculate mean difference
                            md_stat_arr[ch_i_stat] = np.nanmean(d1ch_stat) - np.nanmean(d2ch_stat)
                elif n1_stat_val > 0 and n2_stat_val > 0:
                    # If not enough subjects for t-test, just calculate mean differences
                    for ch_i_stat in range(n_channels_final): 
                        md_stat_arr[ch_i_stat] = np.nanmean(d1_stat_arr[:, ch_i_stat]) - np.nanmean(d2_stat_arr[:, ch_i_stat])
                
                # Store results in dictionary
                current_ms_res_dict[f'{pfx_stat}p_values'] = p_stat_arr
                current_ms_res_dict[f'{pfx_stat}mean_diffs'] = md_stat_arr
                current_ms_res_dict[f'{pfx_stat}t_scores'] = t_stat_arr
                current_ms_res_dict[f'{pfx_stat}cohen_d'] = d_stat_arr
                current_ms_res_dict[f'{pfx_stat}significant_channel'] = np.zeros(n_channels_final, dtype=bool)
                current_ms_res_dict[f'{pfx_stat}n_significant_channel'] = 0
                
            # Store results for this microstate
            all_comparison_results_dict[(c1_stat_main, c2_stat_main)][ms_idx_stat_main] = current_ms_res_dict

    # Apply FDR (False Discovery Rate) correction for multiple comparisons
    print(f"\nApplying FDR correction...")
    if all_comparison_results_dict:
        for pair_key_fdr_main, ms_results_fdr_main in all_comparison_results_dict.items():
            for ms_idx_fdr_main, res_data_fdr_main in ms_results_fdr_main.items():
                for prefix_fdr_main in ['raw_', 'z_']:
                    # Get p-values for this microstate and map type
                    p_values_fdr_main = res_data_fdr_main.get(f'{prefix_fdr_main}p_values', [])
                    
                    # Find valid (non-NaN) p-values
                    valid_p_indices_fdr = ~np.isnan(p_values_fdr_main)
                    p_values_to_correct_fdr = p_values_fdr_main[valid_p_indices_fdr]
                    
                    if len(p_values_to_correct_fdr) > 0:
                        # Apply Benjamini-Hochberg FDR correction
                        reject_fdr_arr, _, _, _ = multipletests(p_values_to_correct_fdr, 0.05, 'fdr_bh')
                        
                        # Store which channels are significant after correction
                        res_data_fdr_main[f'{prefix_fdr_main}significant_channel'][valid_p_indices_fdr] = reject_fdr_arr
                        res_data_fdr_main[f'{prefix_fdr_main}n_significant_channel'] = np.sum(reject_fdr_arr)
    else:
        print("No comparison results to apply FDR correction to.")

    # Generate visualizations if electrode coordinates are available
    if p2d_normalized_coords is not None:
        print("\nGenerating visualizations...")
        ordered_ms_indices_for_plot = [0, 1, 2, 3]  # A, B, C, D
        
        # For each comparison pair
        for (c1_p, c2_p), results_for_pair_p in all_comparison_results_dict.items():
            print(f"  Generating plots for {c1_p} vs {c2_p}...")
            
            # Create output directory for this comparison
            output_dir_for_pair_p = os.path.join(output_base_analysis, f'{c1_p}_vs_{c2_p}')
            os.makedirs(output_dir_for_pair_p, exist_ok=True)
            
            # Set global plotting parameters
            glob_t_max_p = 5.0
            glob_t_min_p = -glob_t_max_p
            glob_d_min_s_p = 0.0
            glob_d_max_s_p = 1.0
            min_m_area_p = 50
            max_m_area_p = 600
            
            # Generate plots for both raw potentials and z-scores
            for plot_var, data_pfx_p, fig_title_sfx_p in [
                ('potential', 'raw_', ""), ('zscore', 'z_', " (Z-Score Data)")]:
                
                # Create figure with 4 subplots (one for each microstate)
                fig_p_main = plt.figure(figsize=(16, 5.5))
                gs_p_main = gridspec.GridSpec(1, 4, figure=fig_p_main, hspace=0.3, wspace=0.1, top=0.88, bottom=0.12, left=0.05, right=0.80)
                sc_handles_p = {}
                
                # Plot each microstate
                for i_plot_c, ms_idx_p_ord in enumerate(ordered_ms_indices_for_plot):
                    if ms_idx_p_ord not in results_for_pair_p: 
                        continue
                    
                    # Get results for this microstate
                    ms_res_p = results_for_pair_p[ms_idx_p_ord]
                    
                    # Create subplot and draw head schematic
                    ax_p = fig_p_main.add_subplot(gs_p_main[0, i_plot_c])
                    draw_head(ax_p)
                    
                    # Get t-scores, p-values and effect sizes for plotting
                    t_sc_p = ms_res_p[f'{data_pfx_p}t_scores']
                    p_val_p_mask = ms_res_p[f'{data_pfx_p}p_values']
                    coh_d_p_size = np.abs(ms_res_p[f'{data_pfx_p}cohen_d'])
                    
                    # Create mask for significant channels (p < 0.05)
                    plot_s_mask = (~np.isnan(p_val_p_mask)) & (p_val_p_mask < 0.05)
                    
                    # Plot channel values: t-scores as colors, effect sizes as circle sizes
                    sc_h_p_val = plot_channel_values(
                        ax_p, p2d_normalized_coords, t_sc_p, 'RdBu_r', glob_t_min_p, glob_t_max_p,
                        sig_mask=plot_s_mask, size_values=coh_d_p_size, size_vmin=glob_d_min_s_p, size_vmax=glob_d_max_s_p,
                        min_marker_size=min_m_area_p, max_marker_size=max_m_area_p
                    )
                    
                    # Set microstate label as subplot title
                    ax_p.set_title(ms_labels[ms_idx_p_ord], fontsize=14)
                    
                    # Store scatter handle for colorbar
                    if i_plot_c == 0 and sc_h_p_val is not None:
                        sc_handles_p['t_stat'] = sc_h_p_val
                
                # Add colorbar for t-statistics
                if 't_stat' in sc_handles_p:
                    cax_t_p_main = fig_p_main.add_axes([0.83, 0.15, 0.015, 0.7])
                    cb_t_p_main = fig_p_main.colorbar(sc_handles_p['t_stat'], cax=cax_t_p_main)
                    cb_t_p_main.set_label('T-Statistic', size=14)
                
                # Create a separate axis for the continuous cone-shaped legend
                s_range_leg_p = glob_d_max_s_p - glob_d_min_s_p
                s_range_leg_p = 1.0 if s_range_leg_p <= 1e-9 else s_range_leg_p
                
                # Create a separate axis for the continuous cone-shaped legend
                cone_ax = fig_p_main.add_axes([0.9, 0.2, 0.08, 0.6])  # Positioned further right
                cone_ax.set_xlim(0, 1)
                cone_ax.set_ylim(0, 1)
                
                # Create a continuous cone shape
                num_points = 100  # Number of points for smoother cone
                y_vals = np.linspace(0, 1, num_points)
                x_center = 0.5
                max_radius = 0.3  # Cone width at base
                
                # Calculate d values along the cone
                d_vals = np.linspace(glob_d_min_s_p, glob_d_max_s_p, num_points)
                
                # Calculate marker sizes that would be used in the actual plot
                norm_d_vals = np.clip((d_vals - glob_d_min_s_p) / s_range_leg_p, 0, 1)
                marker_sizes = min_m_area_p + (max_m_area_p - min_m_area_p) * norm_d_vals
                
                # Scale marker sizes for the legend
                scaled_radii = np.sqrt(marker_sizes / np.pi) / 100  # Scale to fit in the axis
                
                # Draw the cone shape points (circles)
                for i, (y, r) in enumerate(zip(y_vals, scaled_radii)):
                    circle = plt.Circle((x_center, y), r, facecolor='grey', edgecolor='k', lw=0.5)
                    cone_ax.add_patch(circle)
                
                # Set tick and label offsets
                tick_offset = 0.15  # Increased offset
                label_offset = 0.25  # Increased offset
                
                # Draw vertical axis line
                cone_ax.plot([x_center, x_center], [0, 1], 'k-', lw=0.75)
                
                # Add markers and labels for specific d values
                d_markers = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]  # More markers for better reference
                for d in d_markers:
                    # Find the position on y-axis corresponding to this d value
                    y_pos = (d - glob_d_min_s_p) / (glob_d_max_s_p - glob_d_min_s_p) if glob_d_max_s_p != glob_d_min_s_p else 0.5
                    if 0 <= y_pos <= 1:  # Ensure it's within the axis bounds
                        # Draw tick mark with offset
                        cone_ax.plot([x_center - tick_offset, x_center], [y_pos, y_pos], 'k-', lw=0.5)
                        # Label with offset
                        cone_ax.text(x_center - label_offset, y_pos, f"{d:.1f}", ha='right', va='center', fontsize=14)
                
                cone_ax.set_title("abs(Cohen's d)", fontsize=14, pad=10)
                cone_ax.axis('off')  # Hide the axis lines but keep the content
                
                # Set figure title and save
                fig_p_main.suptitle(f'{c1_p} vs {c2_p}{fig_title_sfx_p}', fontsize=16, y=0.96)
                plot_fname_p = os.path.join(output_dir_for_pair_p, f'comparison_{plot_var}_fdr_tTest_cohenDsize{output_suffix}.png')
                fig_p_main.savefig(plot_fname_p, dpi=300, bbox_inches='tight')
                plt.close(fig_p_main)
            
            print(f"    Visualizations saved for {c1_p} vs {c2_p}.")
    else:
        print("\nSkipping visualizations: Normalized 2D electrode coordinates not available.")

    # Save numerical results to CSV files
    print("\nSaving numerical results (CSVs)...")
    try:
        # For each comparison pair
        for (c1_csv_main, c2_csv_main), res_pair_data_csv_main in all_comparison_results_dict.items():
            # Create output directory for this comparison
            out_dir_csv_p_main = os.path.join(output_base_analysis, f'{c1_csv_main}_vs_{c2_csv_main}')
            os.makedirs(out_dir_csv_p_main, exist_ok=True)
            
            # For each microstate in this comparison
            for ms_idx_val_csv_main, ms_res_data_csv_main in res_pair_data_csv_main.items():
                ms_l_csv_main = ms_labels.get(ms_idx_val_csv_main, f'MS{ms_idx_val_csv_main}')
                
                # Save both raw and z-score results
                for pfx_csv_main in ['raw_', 'z_']:
                    # Create DataFrame with all statistical results
                    df_cols_csv_main = {
                        'Channel': master_channel_names,
                        f'MeanDiff_uV' if pfx_csv_main == 'raw_' else 'MeanDiff_Z': ms_res_data_csv_main[f'{pfx_csv_main}mean_diffs'],
                        'T_Statistic': ms_res_data_csv_main[f'{pfx_csv_main}t_scores'], 
                        'Cohen_d': ms_res_data_csv_main[f'{pfx_csv_main}cohen_d'],
                        'p_Value_uncorrected': ms_res_data_csv_main[f'{pfx_csv_main}p_values'],
                        'Significant_FDR_BH': ms_res_data_csv_main[f'{pfx_csv_main}significant_channel']
                    }
                    df_out_csv_main = pd.DataFrame(df_cols_csv_main)
                    
                    # Save to CSV file
                    csv_fpath_main = os.path.join(
                        out_dir_csv_p_main, 
                        f'MS{ms_l_csv_main}_{pfx_csv_main[:-1]}_results_{c1_csv_main}_vs_{c2_csv_main}_fdr_per_ms{output_suffix}.csv'
                    )
                    df_out_csv_main.to_csv(csv_fpath_main, index=False, float_format='%.6f', na_rep='NaN')
        
        print("  CSV files saved.")
    except Exception as e_save_csv_main:
        print(f"  Error saving CSV files: {e_save_csv_main}")

    # Generate and save summary report
    print("\nGenerating summary report...")
    report_fpath_main = os.path.join(output_base_analysis, f'summary_report_fdr_per_ms{output_suffix}.txt')
    try:
        with open(report_fpath_main, 'w', encoding='utf-8') as fr_main:
            # Write report header
            fr_main.write("MICROSTATE COMPARISON SUMMARY REPORT\n" + "="*40 + "\n\n")
            fr_main.write(f"Analysis Output Dir: {os.path.abspath(output_base_analysis)}\nRef Cond: {ref_name_key_main}\n")
            fr_main.write(f"All Processed Conds: {list(condition_data_store.keys())}\nComparison Pairs: {comparison_pairs_final_list}\n")
            fr_main.write(f"Outlier Exclusion: {exclude_outliers} (MS-Wise if True)\n\nInitial Subject Counts (before MS-specific exclusion for stats):\n")
            
            # Write subject counts for each condition
            for ck_rpi, cts_rpi in initial_subject_counts_dict.items():
                fr_main.write(f"  - {ck_rpi}: Files Found={cts_rpi['found_files']}, Subjects Loaded={cts_rpi['loaded_subjects']}\n")
            
            # Write FDR information
            fr_main.write("\nFDR (Benjamini/Hochberg, alpha=0.05) applied PER MICROSTATE across channels.\n" + "-"*75 + "\n\n")
            
            if not all_comparison_results_dict:
                fr_main.write("No comparison results generated.\n")
            else:
                # For each comparison pair
                for (c1_rm, c2_rm), data_for_rp_pair_m in all_comparison_results_dict.items():
                    # Write comparison header
                    fr_main.write(f"Comparison: {c1_rm} vs {c2_rm}\n" + "="*len(f"Comparison: {c1_rm} vs {c2_rm}") + "\n\n")
                    fr_main.write("Sig. Channels (FDR p<0.05) & N subjects per MS in T-test:\n")
                    fr_main.write("---------------------------------------------------------------------------\n")
                    fr_main.write("| Microstate | N (Raw: c1/c2) | N (Z-Sc: c1/c2) | Sig. Ch Raw | Sig. Ch Z |\n")
                    fr_main.write("|------------|----------------|-----------------|-------------|-----------|\n")
                    
                    any_s_p_m = False
                    # For each microstate in this comparison
                    for ms_idx_rm in range(4):
                        ms_l_rm = ms_labels.get(ms_idx_rm, f'MS{ms_idx_rm}')
                        ms_res_rm = data_for_rp_pair_m.get(ms_idx_rm, {})
                        
                        # Get subject counts
                        n_r_v_rm = ms_res_rm.get('n_raw', ('?', '?'))
                        n_z_v_rm = ms_res_rm.get('n_z', ('?', '?'))
                        n_r_s_rm = f"{n_r_v_rm[0]}/{n_r_v_rm[1]}"
                        n_z_s_rm = f"{n_z_v_rm[0]}/{n_z_v_rm[1]}"
                        
                        # Get significant channel counts
                        n_s_r_val = ms_res_rm.get('raw_n_significant_channel', 0)
                        n_s_z_val = ms_res_rm.get('z_n_significant_channel', 0)
                        
                        # Write table row
                        fr_main.write(f"| {ms_l_rm:<10} | {n_r_s_rm:<14} | {n_z_s_rm:<15} | {n_s_r_val:<11} | {n_s_z_val:<9} |\n")
                        
                        if n_s_r_val > 0 or n_s_z_val > 0:
                            any_s_p_m = True
                    
                    fr_main.write("---------------------------------------------------------------------------\n")
                    if not any_s_p_m:
                        fr_main.write("  (No FDR significant channels for this pair)\n")
                    fr_main.write("\n" + "-"*40 + "\n\n")
        
        print(f"Summary report saved to: {report_fpath_main}")
    except Exception as e_write_rep_main:
        print(f"Error generating summary report: {e_write_rep_main}")

    print(f"\nAnalysis complete. Results in: {os.path.abspath(output_base_analysis)}")
    return all_comparison_results_dict

# Run the topographic comparison twice for the older HC group: once with
# microstate-wise outlier exclusion and once on every loaded subject.
## Older HC group: Awake (reference) vs N2, Awake vs REM, and N2 vs REM topographies

if __name__ == "__main__":

    # Inputs shared by both runs. The reference condition's key is derived by
    # the called function from the folder name (basename with '_median' removed).
    older_ref_path = './median_microstates_and_outliers/HC_Awake_median'
    older_other_paths = {
        'HC_N2': './median_microstates_and_outliers/HC_N2_median',
        'HC_REM': './median_microstates_and_outliers/HC_REM_median'
    }
    electrode_file = 'Electrodes 10-20_v3.20.xyz'

    # NOTE(review): the outlier-excluded run writes under 'comparison_results'
    # while the all-subjects run writes under 'comparison_older_results';
    # confirm this directory split is intended.

    # Run 1: exclude per-microstate outliers listed in the `outliers` mapping
    # (defined elsewhere in this notebook); output files get '_no_outliers'.
    results_no_outliers = compare_median_microstates_across_conditions(
        ref_path=older_ref_path,
        other_paths=dict(older_other_paths),  # each run receives its own dict object
        output_base='./microstate_analysis_output_newplots4/comparison_results',
        xyz_file=electrode_file,
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers'
    )

    # Run 2: keep every loaded subject (no exclusion, no filename suffix).
    results_all = compare_median_microstates_across_conditions(
        ref_path=older_ref_path,
        other_paths=dict(older_other_paths),
        output_base='./microstate_analysis_output_newplots4/comparison_older_results',
        xyz_file=electrode_file,
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix=''
    )

In [None]:
## Younger group: compare the Awake reference topographies against N2 and REM

if __name__ == "__main__":
    # Arguments shared by both runs; they differ only in outlier handling and
    # output_suffix.
    _younger_stage_args = dict(
        ref_path='./median_microstates_and_outliers/Young_Awake_median',
        output_base='./microstate_analysis_output_newplots4/comparison_younger_results',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    # The keys presumably serve as condition labels in the results/reports.
    # They follow the '<group>_<stage>' naming of the paths they point at.
    # These are younger-group recordings, so they are labelled 'Young_*', not 'HC_*'.
    _younger_stage_targets = {
        'Young_N2': './median_microstates_and_outliers/Young_N2_median',
        'Young_REM': './median_microstates_and_outliers/Young_REM_median',
    }

    # Run 1: exclude the subjects listed in `outliers`. `outliers` is not
    # defined in this cell; it must already exist in the session.
    results_no_outliers = compare_median_microstates_across_conditions(
        **_younger_stage_args,
        # Fresh dict per call, so a mutation inside one run cannot leak into the next.
        other_paths=dict(_younger_stage_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Run 2: keep every subject.
    results_all = compare_median_microstates_across_conditions(
        **_younger_stage_args,
        other_paths=dict(_younger_stage_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )

In [None]:
## Compare the older HC group vs Younger HC group across sleep stages

if __name__ == "__main__":
    # Wake stage: older HC cohort (reference) against the younger cohort.
    # NOTE(review): the Awake, N2 and REM old-vs-young runs all use this same
    # output_base and suffix. Confirm compare_median_microstates_across_conditions
    # separates its outputs per reference/comparison pair so they don't overwrite.
    _awake_shared = dict(
        ref_path='./median_microstates_and_outliers/HC_Awake_median',
        output_base='./microstate_analysis_output_newplots4/Old_vs_young',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _awake_targets = {
        'Young_Awake': './median_microstates_and_outliers/Young_Awake_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_awake_shared,
        other_paths=dict(_awake_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_awake_shared,
        other_paths=dict(_awake_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )


if __name__ == "__main__":
    # N2 stage: older HC cohort (reference) against the younger cohort.
    # NOTE(review): shares output_base and suffix with the Awake/REM
    # old-vs-young runs; confirm outputs are separated per comparison pair.
    _n2_shared = dict(
        ref_path='./median_microstates_and_outliers/HC_N2_median',
        output_base='./microstate_analysis_output_newplots4/Old_vs_young',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _n2_targets = {
        'Young_N2': './median_microstates_and_outliers/Young_N2_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_n2_shared,
        other_paths=dict(_n2_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_n2_shared,
        other_paths=dict(_n2_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )

if __name__ == "__main__":
    # REM stage: older HC cohort (reference) against the younger cohort.
    # NOTE(review): shares output_base and suffix with the Awake/N2
    # old-vs-young runs; confirm outputs are separated per comparison pair.
    _rem_shared = dict(
        ref_path='./median_microstates_and_outliers/HC_REM_median',
        output_base='./microstate_analysis_output_newplots4/Old_vs_young',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _rem_targets = {
        'Young_REM': './median_microstates_and_outliers/Young_REM_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_rem_shared,
        other_paths=dict(_rem_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_rem_shared,
        other_paths=dict(_rem_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )

In [None]:
## Compare pathological groups

if __name__ == "__main__":
    # Wake stage: ADNoEp (reference) against ADEp and HC.
    # NOTE(review): the Awake, N2 and REM pathological runs all use this same
    # output_base and suffix. Confirm compare_median_microstates_across_conditions
    # separates its outputs per reference/comparison pair so they don't overwrite.
    _patho_awake_shared = dict(
        ref_path='./median_microstates_and_outliers/ADNoEp_Awake_median',
        output_base='./microstate_analysis_output_newplots4/pathological_results',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _patho_awake_targets = {
        'ADEp_Awake': './median_microstates_and_outliers/ADEp_Awake_median',
        'HC_Awake': './median_microstates_and_outliers/HC_Awake_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_patho_awake_shared,
        other_paths=dict(_patho_awake_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_patho_awake_shared,
        other_paths=dict(_patho_awake_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )

if __name__ == "__main__":
    # N2 stage: ADNoEp (reference) against ADEp and HC.
    # NOTE(review): shares output_base and suffix with the Awake/REM
    # pathological runs; confirm outputs are separated per comparison pair.
    _patho_n2_shared = dict(
        ref_path='./median_microstates_and_outliers/ADNoEp_N2_median',
        output_base='./microstate_analysis_output_newplots4/pathological_results',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _patho_n2_targets = {
        'ADEp_N2': './median_microstates_and_outliers/ADEp_N2_median',
        'HC_N2': './median_microstates_and_outliers/HC_N2_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_patho_n2_shared,
        other_paths=dict(_patho_n2_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_patho_n2_shared,
        other_paths=dict(_patho_n2_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )

if __name__ == "__main__":
    # REM stage: ADNoEp (reference) against ADEp and HC.
    # NOTE(review): shares output_base and suffix with the Awake/N2
    # pathological runs; confirm outputs are separated per comparison pair.
    _patho_rem_shared = dict(
        ref_path='./median_microstates_and_outliers/ADNoEp_REM_median',
        output_base='./microstate_analysis_output_newplots4/pathological_results',
        xyz_file='Electrodes 10-20_v3.20.xyz',
    )
    _patho_rem_targets = {
        'ADEp_REM': './median_microstates_and_outliers/ADEp_REM_median',
        'HC_REM': './median_microstates_and_outliers/HC_REM_median',
    }

    # Pass 1: drop the subjects listed in `outliers` (not defined in this cell;
    # it must already exist in the session).
    results_no_outliers = compare_median_microstates_across_conditions(
        **_patho_rem_shared,
        other_paths=dict(_patho_rem_targets),
        exclude_outliers=True,
        outlier_subjects=outliers,
        output_suffix='_no_outliers',
    )

    # Pass 2: retain every subject.
    results_all = compare_median_microstates_across_conditions(
        **_patho_rem_shared,
        other_paths=dict(_patho_rem_targets),
        exclude_outliers=False,
        outlier_subjects=None,
        output_suffix='',
    )