In [None]:
# ### 1. Import Necessary Libraries

#%%
import mne
from mne_connectivity import spectral_connectivity_epochs
import numpy as np
import networkx as nx
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score, learning_curve
from sklearn.metrics import classification_report, roc_curve, auc, RocCurveDisplay
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import os

In [None]:


# --- Configuration ---
# 1. Base EEG data folder. It directly contains the .fif files (there are no
#    per-subject subfolders). Set the ATTENTION_EEG_DATA_DIR environment
#    variable to use another location without editing the notebook; otherwise
#    the original local path below is used.
base_eeg_folder = os.environ.get(
    "ATTENTION_EEG_DATA_DIR",
    "C:\\Users\\LENOVO\\OneDrive - City University\\Desktop\\Attention detection FYP\\Data",
)

# 2. Subject IDs to load. Each subject needs Att_<ID>_cleaned_raw.fif and
#    Inatt_<ID>_cleaned_raw.fif in base_eeg_folder; a subject with a missing
#    file is reported and skipped.
subject_ids = ['S02','S03','S04', 'S05','S06','S07','S08','S09','S10','S11']
# artifact_subject_ids = ['V01', 'V02', 'V03', 'V04', 'V05', 'V06'] # List of subjects with artifacts

# --- Dictionaries keyed by subject ID holding raw data and epochs ---
raw_data_by_subject = {}
epochs_by_subject = {}

# Channels kept for every subject.
channels = ['Fp1', 'Fp2', 'Fz', 'F3', 'F4', 'C3', 'Cz', 'C4']

# --- Loop through each subject ID to load and preprocess their data ---
print(f"--- Starting data loading for subjects: {subject_ids} ---")
for current_subject_id in subject_ids:
    print(f"\nProcessing data for {current_subject_id}...")

    # Files sit directly in base_eeg_folder (no subject subfolder).
    attention_file = os.path.join(base_eeg_folder, f"Att_{current_subject_id}_cleaned_raw.fif")
    inattention_file = os.path.join(base_eeg_folder, f"Inatt_{current_subject_id}_cleaned_raw.fif")
    # NOTE: the 'Inatt' prefix may be case-sensitive depending on the file
    # system. If a subject's file deviates from this pattern (e.g. an earlier
    # note mentioned 'inatt_S05_cleaned_raw3.fif'), adjust the name for that
    # subject; otherwise it is skipped by the FileNotFoundError handler below.

    # --- Show the paths being loaded (useful for debugging) ---
    print(f"  Attempting to load Attention file: {attention_file}")
    print(f"  Attempting to load Inattention file: {inattention_file}")

    # --- Load the raw FIF files ---
    try:
        raw_att = mne.io.read_raw_fif(attention_file, preload=True, verbose=False)
        raw_inatt = mne.io.read_raw_fif(inattention_file, preload=True, verbose=False)
        # Shapes are built from metadata: calling get_data() here would copy the
        # entire recording just to read its .shape.
        print(f"  Raw Attention data shape: {(len(raw_att.ch_names), raw_att.n_times)} (channels x time points)")
        print(f"  Raw Inattention data shape: {(len(raw_inatt.ch_names), raw_inatt.n_times)} (channels x time points)")
    except FileNotFoundError as e:
        print(f"  ERROR: File not found for {current_subject_id}. Skipping this subject. Details: {e}")
        continue # Skip to the next subject if files are missing

    # Keep only the configured channels (pick_channels operates on the Raw
    # objects in place).
    # NOTE(review): pick_channels is a legacy API in recent MNE; inst.pick(...)
    # is the modern equivalent - verify resulting channel order before switching.
    raw_att.pick_channels(channels, verbose=False)
    raw_inatt.pick_channels(channels, verbose=False)
    print(f"  Raw Attention data shape after channel selection: {(len(raw_att.ch_names), raw_att.n_times)}")
    print(f"  Raw Inattention data shape after channel selection: {(len(raw_inatt.ch_names), raw_inatt.n_times)}")

    # Store raw data objects
    raw_data_by_subject[current_subject_id] = {'raw_att': raw_att, 'raw_inatt': raw_inatt}

    # --- Epoching: 10 s fixed-length windows overlapping by 5 s ---
    epochs_att = mne.make_fixed_length_epochs(raw_att, duration=10.0, overlap=5, preload=True, verbose=False)
    epochs_inatt = mne.make_fixed_length_epochs(raw_inatt, duration=10.0, overlap=5, preload=True, verbose=False)

    # Epoch array shape (epochs x channels x time points), again from metadata
    # so the epoched data is not copied just for logging.
    print(f"  {current_subject_id} Attention epochs: {len(epochs_att)} epochs, data shape: {(len(epochs_att), len(epochs_att.ch_names), len(epochs_att.times))} (epochs x channels x time points)")
    print(f"  {current_subject_id} Inattention epochs: {len(epochs_inatt)} epochs, data shape: {(len(epochs_inatt), len(epochs_inatt.ch_names), len(epochs_inatt.times))} (epochs x channels x time points)")

    # Store epoch objects
    epochs_by_subject[current_subject_id] = {'epochs_att': epochs_att, 'epochs_inatt': epochs_inatt}

print("\n--- Finished loading and epoching data for all specified subjects ---")

In [None]:

import numpy as np
import mne
from mne_connectivity import spectral_connectivity_epochs
from sklearn.feature_selection import mutual_info_regression
from joblib import Parallel, delayed # For parallelizing MI computation
from scipy.stats import pearsonr # Import pearsonr for Pearson's correlation

# --- Connectivity Computation Functions ---

def compute_spectral_connectivity_per_epoch(epochs, method, fmin=4, fmax=35):
    """
    Computes band-averaged spectral connectivity (Coherence or dPLI) for each epoch separately.

    Every epoch is passed to ``spectral_connectivity_epochs`` on its own
    (multitaper, ``faverage=True``), so each returned matrix reflects that single
    epoch only.

    Parameters
    ----------
    epochs : mne.Epochs
        The MNE Epochs object containing the EEG data.
    method : str
        The connectivity method to compute ('coh' for Coherence, 'dpli' for dPLI).
        Other method strings are passed through without the post-processing below.
    fmin : int
        Minimum frequency of the band of interest.
    fmax : int
        Maximum frequency of the band of interest.

    Returns
    -------
    numpy.ndarray
        An array of connectivity matrices, shape (n_epochs, n_channels, n_channels).
        An Epochs object with no epochs yields an empty (0, n_channels, n_channels) array.
    """
    sfreq = epochs.info['sfreq']
    n_channels = len(epochs.ch_names)

    all_connectivity_matrices = []

    for i in range(len(epochs)):
        con = spectral_connectivity_epochs(
            epochs[i:i+1],  # single-epoch Epochs object
            method=method,
            mode='multitaper',
            sfreq=sfreq,
            fmin=fmin,
            fmax=fmax,
            faverage=True,
            verbose=False,
            n_jobs=1 # MNE's n_jobs for this specific function
        )

        data = con.get_data(output='dense')
        # A 3D dense result carries a trailing frequency axis; with faverage=True
        # the first (presumably only) band is taken.
        connectivity_matrix = data[:, :, 0] if data.ndim == 3 else data

        if method == 'coh':
            # The all-to-all dense output fills only one triangle (the other is 0);
            # the element-wise maximum with the transpose mirrors it so the matrix is
            # symmetric. Self-coherence is 1 by definition.
            connectivity_matrix = np.maximum(connectivity_matrix, connectivity_matrix.T)
            np.fill_diagonal(connectivity_matrix, 1.0)
        elif method == 'dpli':
            # NOTE(review): dPLI is directional - dPLI(j, i) = 1 - dPLI(i, j), and 0.5
            # means no consistent lead/lag. Mirroring the filled triangle here keeps
            # dPLI(i, j) on both sides, discarding direction, and the downstream
            # abs()-based proportional thresholding ranks values by raw magnitude
            # rather than distance from 0.5. Kept as-is so existing features are
            # unchanged; consider |dPLI - 0.5| if lag strength is the intent.
            connectivity_matrix = np.maximum(connectivity_matrix, connectivity_matrix.T)
            np.fill_diagonal(connectivity_matrix, 0.0)

        all_connectivity_matrices.append(connectivity_matrix)

    if not all_connectivity_matrices:
        # Keep the documented 3D shape even when there are no epochs.
        return np.zeros((0, n_channels, n_channels))
    return np.array(all_connectivity_matrices)

def compute_mi_per_epoch_single(epoch_data, n_channels, n_neighbors=3):
    """
    Build a symmetric Mutual Information matrix for a single epoch.

    Intended to be dispatched once per epoch through joblib.

    Parameters
    ----------
    epoch_data : numpy.ndarray
        2D array indexed as ``epoch_data[channel, sample]``.
    n_channels : int
        Number of leading channels of ``epoch_data`` to compare pairwise.
    n_neighbors : int, optional
        k for sklearn's k-nearest-neighbour MI estimator. Default 3.

    Returns
    -------
    numpy.ndarray
        (n_channels, n_channels) matrix; each unordered channel pair's MI estimate
        is written to both off-diagonal positions, the diagonal stays 0.
    """
    mi_matrix = np.zeros((n_channels, n_channels))

    # Each unordered pair (row < col) is estimated once and mirrored.
    for row in range(n_channels - 1):
        # sklearn expects a 2D feature matrix, so the "feature" channel is a column vector.
        feature = epoch_data[row, :].reshape(-1, 1)
        for col in range(row + 1, n_channels):
            target = epoch_data[col, :]
            # Fixed random_state for reproducible estimates.
            value = mutual_info_regression(
                feature, target, n_neighbors=n_neighbors, random_state=42
            )[0]
            mi_matrix[row, col] = mi_matrix[col, row] = value

    return mi_matrix


def compute_mutual_information_per_epoch(epochs, n_neighbors=3, n_jobs=-1):
    """
    Computes Mutual Information (MI) between all pairs of channels for each epoch,
    using joblib to process epochs in parallel.

    Parameters
    ----------
    epochs : mne.Epochs
        The MNE Epochs object containing the data. Only the channels returned by
        ``epochs.get_data(picks='eeg')`` are analysed.
    n_neighbors : int, optional
        Number of nearest neighbors to use for MI estimation. The default is 3.
    n_jobs : int, optional
        Number of CPU cores to use for parallel processing across epochs.
        -1 means use all available cores. The default is -1.

    Returns
    -------
    all_mi : numpy.ndarray
        An array of MI matrices, with shape (n_epochs, n_channels, n_channels),
        where n_channels is the number of EEG channels actually extracted.
        Each matrix is symmetric; the diagonal is 0 because self-MI is not
        computed (for continuous data it would be unbounded).
    """
    n_epochs = len(epochs)

    print(f"   Starting parallel MI computation for {n_epochs} epochs with {n_jobs} cores...")

    # Shape: (n_epochs, n_channels, n_times)
    all_epochs_data = epochs.get_data(picks='eeg')
    # Take the channel count from the extracted array. len(epochs.ch_names) also
    # counts non-EEG channels that picks='eeg' drops, which would make the helper
    # index past the end of the data.
    n_channels = all_epochs_data.shape[1]

    if all_epochs_data.shape[0] == 0:
        # Keep the documented 3D shape even when there is nothing to compute.
        return np.zeros((0, n_channels, n_channels))

    # One task per epoch; joblib returns results in input order.
    all_mi_matrices = np.array(Parallel(n_jobs=n_jobs)(
        delayed(compute_mi_per_epoch_single)(epoch_data, n_channels, n_neighbors)
        for epoch_data in all_epochs_data
    ))

    # Safeguard: force a zero diagonal on every matrix (the per-epoch helper only
    # fills off-diagonal pairs).
    diag = np.arange(n_channels)
    all_mi_matrices[:, diag, diag] = 0.0

    return all_mi_matrices

# --- NEW FUNCTION FOR PEARSON'S CORRELATION ---
def compute_pearson_correlation_per_epoch(epochs):
    """
    Computes Pearson's Correlation Coefficient between every pair of channels, per epoch.

    Parameters
    ----------
    epochs : mne.Epochs
        The MNE Epochs object containing the EEG data. Only the channels returned by
        ``epochs.get_data(picks='eeg')`` are correlated.

    Returns
    -------
    numpy.ndarray
        An array of Pearson correlation matrices, shape (n_epochs, n_channels, n_channels),
        where n_channels is the number of EEG channels actually extracted.
        Each matrix is exactly symmetric with 1.0 on the diagonal (correlation of a
        signal with itself). Pairs involving a constant (zero-variance) channel are NaN.
    """
    n_epochs = len(epochs)

    print(f"   Computing Pearson's Correlation for {n_epochs} epochs...")

    # Shape: (n_epochs, n_channels, n_time_points)
    epochs_data = epochs.get_data(picks='eeg')
    # Size the matrices from the array that is actually correlated.
    # len(epochs.ch_names) would also count non-EEG channels excluded by
    # picks='eeg' and index past the end of the data.
    n_channels = epochs_data.shape[1]
    upper_rows, upper_cols = np.triu_indices(n_channels, k=1)

    all_correlation_matrices = np.empty((epochs_data.shape[0], n_channels, n_channels))
    for i, epoch_data in enumerate(epochs_data):
        correlation_matrix = np.eye(n_channels)  # self-correlation fixed at 1
        if n_channels > 1:
            # One vectorised call (rows are channels) replaces the per-pair loop;
            # a zero-variance channel yields NaN for its pairs.
            full_corr = np.corrcoef(epoch_data)
            upper_values = full_corr[upper_rows, upper_cols]
            # Copy one triangle to both sides so the result is exactly symmetric.
            correlation_matrix[upper_rows, upper_cols] = upper_values
            correlation_matrix[upper_cols, upper_rows] = upper_values
        all_correlation_matrices[i] = correlation_matrix

    return all_correlation_matrices

# --- Store connectivity results for each subject ---
# This dictionary will now store COH, DPLI, MI, and Pearson's R for each subject and condition.
# For every subject in epochs_by_subject, four per-epoch connectivity measures are
# computed for both conditions; each stored value is the 3D array
# (n_epochs, n_channels, n_channels) returned by the functions above.
# The key order of each subject's dict (set at the end of the loop body) is the
# order in which the graph-conversion step later iterates over measures.
# NOTE: the loop variables (coh_att ... pearson_r_inatt, epochs_*_current) remain
# defined after this cell and hold the LAST processed subject's values.
all_connectivity_by_subject = {}

print("\n--- Starting connectivity computation (COH, DPLI, MI, and Pearson's R) for all subjects ---")
for subject_id, epochs_data in epochs_by_subject.items():
    print(f"\nComputing connectivity for {subject_id}...")

    epochs_att_current = epochs_data['epochs_att']
    epochs_inatt_current = epochs_data['epochs_inatt']

    # --- Compute COH ---
    coh_att = compute_spectral_connectivity_per_epoch(epochs_att_current, method='coh')
    print(f"   {subject_id} Attention COH shape: {coh_att.shape}")
    coh_inatt = compute_spectral_connectivity_per_epoch(epochs_inatt_current, method='coh')
    print(f"   {subject_id} Inattention COH shape: {coh_inatt.shape}")

    # --- Compute DPLI ---
    dpli_att = compute_spectral_connectivity_per_epoch(epochs_att_current, method='dpli')
    print(f"   {subject_id} Attention DPLI shape: {dpli_att.shape}")
    dpli_inatt = compute_spectral_connectivity_per_epoch(epochs_inatt_current, method='dpli')
    print(f"   {subject_id} Inattention DPLI shape: {dpli_inatt.shape}")

    # --- Compute Mutual Information (MI) --- (parallelised over epochs via joblib)
    mi_att = compute_mutual_information_per_epoch(epochs_att_current)
    print(f"   {subject_id} Attention MI shape: {mi_att.shape}")
    mi_inatt = compute_mutual_information_per_epoch(epochs_inatt_current)
    print(f"   {subject_id} Inattention MI shape: {mi_inatt.shape}")

    # --- Compute Pearson's Correlation (R) ---
    pearson_r_att = compute_pearson_correlation_per_epoch(epochs_att_current)
    print(f"   {subject_id} Attention Pearson's R shape: {pearson_r_att.shape}")
    pearson_r_inatt = compute_pearson_correlation_per_epoch(epochs_inatt_current)
    print(f"   {subject_id} Inattention Pearson's R shape: {pearson_r_inatt.shape}")


    # Store all results, including Pearson's R. Keys are '<measure>_<condition>'
    # and must match the keys of method_densities used for graph thresholding.
    all_connectivity_by_subject[subject_id] = {
        'coh_att': coh_att,
        'dpli_att': dpli_att,
        'mi_att': mi_att,
        'pearson_r_att': pearson_r_att, # Added Pearson's R here
        'coh_inatt': coh_inatt,
        'dpli_inatt': dpli_inatt,
        'mi_inatt': mi_inatt,
        'pearson_r_inatt': pearson_r_inatt # Added Pearson's R here
    }

print("\n--- Finished connectivity computation for all subjects ---")

# --- Example of how to access the results (you can add visualization here) ---
# For a specific subject and condition:
# (NOTE: 'S01' is not in subject_ids above; substitute a loaded ID such as 'S02'.)
# subject_example = 'S01'
# if subject_example in all_connectivity_by_subject:
#     # Access the array of Pearson's R matrices for Attention
#     pearson_r_matrices_att_s01 = all_connectivity_by_subject[subject_example]['pearson_r_att']
#     # You can then average them, or use them as per-epoch features
#     avg_pearson_r_att_s01 = np.mean(pearson_r_matrices_att_s01, axis=0)
#     print(f"\nAverage Pearson's R matrix for {subject_example} (Attention):\n{avg_pearson_r_att_s01}")

In [None]:

import numpy as np
import networkx as nx # Assuming you have networkx imported for graph operations

# --- Function to apply proportional thresholding to a connectivity matrix ---
def proportional_threshold(matrix, density):
    """
    Keep only the strongest ``density`` fraction of edges in a connectivity matrix.

    Edge strength is the absolute value of each entry. The cut-off is chosen from
    the upper triangle (k=1) only, so the matrix is assumed to be symmetric (or
    effectively symmetric for graph purposes). NaN entries (e.g. a correlation
    involving a flat channel) are treated as "no connection" (strength 0), so
    they can neither become the cut-off value nor be kept as edges.

    Parameters
    ----------
    matrix : numpy.ndarray
        A 2D square connectivity matrix. It is not modified.
    density : float
        The desired density of the resulting graph: the proportion of the
        n*(n-1)/2 possible edges to keep (rounded down).
        Must be strictly between 0 and 1.

    Returns
    -------
    numpy.ndarray
        A new matrix holding the ABSOLUTE values of the retained entries, with
        every other entry (including the diagonal) set to 0. All entries tied
        with the cut-off value are kept, so the realised density can be slightly
        higher than requested.

    Raises
    ------
    ValueError
        If ``density`` is not strictly between 0 and 1.
    """
    if not (0 < density < 1):
        raise ValueError("Density must be strictly between 0 and 1.")

    # Work on a fresh array of magnitudes so the caller's matrix is untouched.
    connectivity_values = np.abs(matrix)
    # NaN would sort to the FRONT of a descending sort and could become the
    # cut-off; every `>= NaN` comparison is False, which silently empties the
    # graph. Treat undefined connectivity as zero strength instead.
    connectivity_values = np.where(np.isnan(connectivity_values), 0, connectivity_values)

    # Self-connections must not compete for the retained-edge budget.
    np.fill_diagonal(connectivity_values, 0)

    # Unique off-diagonal pairs (upper triangle; matrix assumed symmetric).
    upper_tri_elements = connectivity_values[np.triu_indices_from(connectivity_values, k=1)]
    if upper_tri_elements.size == 0:
        return np.zeros_like(matrix)  # 0x0 or 1x1 input: no possible edges

    num_edges_to_keep = int(upper_tri_elements.size * density)
    if num_edges_to_keep == 0:  # density too low to keep any edge
        return np.zeros_like(matrix)

    # The cut-off is the weakest weight among the strongest `num_edges_to_keep`.
    sorted_weights = np.sort(upper_tri_elements)[::-1]  # descending
    threshold_value = sorted_weights[num_edges_to_keep - 1]

    # Keep values at or above the cut-off (ties included), zero everything else.
    return np.where(connectivity_values >= threshold_value, connectivity_values, 0)


def threshold_and_convert_to_graph(connectivity_matrices, density=0.4):
    """
    Proportionally threshold each connectivity matrix and wrap it in a weighted,
    undirected NetworkX graph.

    Parameters
    ----------
    connectivity_matrices : numpy.ndarray (3D) or iterable of 2D numpy.ndarray
        Per-epoch connectivity matrices, e.g. the 3D arrays produced by the
        per-epoch connectivity functions.
    density : float, optional
        Proportion of possible edges to keep in each graph; forwarded to
        ``proportional_threshold`` (must be strictly between 0 and 1).
        Default is 0.4.

    Returns
    -------
    list of networkx.Graph
        One graph per input matrix, in input order. Edge weights are the
        absolute connectivity values kept by the threshold.
    """
    def _matrix_to_graph(single_matrix):
        # proportional_threshold takes absolute values and zeroes the diagonal;
        # it does NOT symmetrise, so the matrix is expected to be symmetric already.
        graph = nx.from_numpy_array(proportional_threshold(single_matrix, density))
        # Safeguard only: a zero diagonal means there should be no self-loops.
        graph.remove_edges_from(nx.selfloop_edges(graph))
        return graph

    # A 3D ndarray is iterated slice by slice; any other iterable is first
    # materialised into a list.
    if isinstance(connectivity_matrices, np.ndarray) and connectivity_matrices.ndim == 3:
        source = connectivity_matrices
    else:
        source = list(connectivity_matrices)

    return [_matrix_to_graph(single_matrix) for single_matrix in source]


# Initialize a dictionary to store all types of graphs per subject
# This will have a nested structure: graphs_by_subject[subject_id][connectivity_type_and_condition]
# Each leaf is a list of networkx.Graph objects, one per epoch, in epoch order.
graphs_by_subject = {}

# --- Define a dictionary for preferred densities for each method ---
# Proportion of possible edges kept by proportional_threshold for each
# '<measure>_<condition>' key of all_connectivity_by_subject. Values must lie
# strictly between 0 and 1 (proportional_threshold raises ValueError otherwise).
# ADDED ENTRIES FOR PEARSON'S R
method_densities = {
    'coh_att': 0.5,
    'dpli_att': 0.4,
    'mi_att': 0.4,
    'pearson_r_att': 0.4, # Added Pearson's R density for attention
    'coh_inatt': 0.5,
    'dpli_inatt': 0.4,
    'mi_inatt': 0.4,
    'pearson_r_inatt': 0.4, # Added Pearson's R density for inattention
}

# Iterate through each subject's data from the all_connectivity_by_subject dictionary.
# NOTE: loop variables (subject_id, data, conn_type_key, current_graphs, ...) remain
# defined after this cell and refer to the last subject / key processed.
print("\n--- Starting graph conversion for all subjects and connectivity types ---")
for subject_id, data in all_connectivity_by_subject.items():
    print(f"\nProcessing graphs for {subject_id}...")
    
    # Initialize a sub-dictionary for the current subject's graphs
    graphs_by_subject[subject_id] = {}

    # Iterate through each connectivity type and its corresponding matrices for the current subject
    # 'data.items()' yields (key, value) pairs like ('coh_att', <3D array of per-epoch matrices>)
    for conn_type_key, connectivity_matrices_list in data.items():
        # Get the appropriate density for the current connectivity type
        # If the key is not in method_densities, it will use the default of 0.4
        current_density = method_densities.get(conn_type_key, 0.4) 

        print(f"   Processing {subject_id}: {conn_type_key} with density {current_density}...")
        
        # Apply thresholding and convert to graphs
        current_graphs = threshold_and_convert_to_graph(connectivity_matrices_list, density=current_density)
        
        # Store the list of graphs under the appropriate key for the subject
        graphs_by_subject[subject_id][conn_type_key] = current_graphs
        
        print(f"    {subject_id}: {len(current_graphs)} {conn_type_key} graphs created.")

print("\n--- Finished graph conversion for all subjects ---")

# Example of how to access the processed graphs:
# subject_id_example = list(graphs_by_subject.keys())[0] # Get the first subject ID
# first_coh_att_graph = graphs_by_subject[subject_id_example]['coh_att'][0]
# print(f"\nExample: First COH Attention graph for {subject_id_example}: {first_coh_att_graph}")
# print(f"Number of nodes: {first_coh_att_graph.number_of_nodes()}, Number of edges: {first_coh_att_graph.number_of_edges()}")

# Accessing a Pearson's R graph:
# if 'pearson_r_att' in graphs_by_subject[subject_id_example]:
#     first_pearson_r_att_graph = graphs_by_subject[subject_id_example]['pearson_r_att'][0]
#     print(f"\nExample: First Pearson's R Attention graph for {subject_id_example}: {first_pearson_r_att_graph}")
#     print(f"Number of nodes: {first_pearson_r_att_graph.number_of_nodes()}, Number of edges: {first_pearson_r_att_graph.number_of_edges()}")

In [None]:

import networkx as nx
import numpy as np
import community as community_louvain # Make sure this is installed (pip install python-louvain)
from sklearn.impute import SimpleImputer # For handling NaNs

# --- Updated extract_graph_features function ---
def extract_graph_features(graphs):
    """
    Extract traditional and advanced graph features from a list of NetworkX graphs.

    Features that consistently returned NaN in earlier analysis (local efficiency,
    rich-club coefficient, small-world omega/sigma) have been removed.

    Parameters
    ----------
    graphs : iterable of networkx.Graph
        Undirected graphs whose edges carry a numeric ``'weight'`` attribute.

    Returns
    -------
    features : numpy.ndarray
        Shape (n_graphs, 36); columns follow ``feature_names``. A feature that
        cannot be computed for a graph is NaN. An empty input yields a (0, 36) array.
    feature_names : list of str
        The 36 feature names, in column order.

    Notes
    -----
    NOTE(review): the path-based metrics (char_path_length, betweenness,
    closeness) pass the edge ``'weight'`` to NetworkX as a *distance*. If weights
    are connectivity strengths (higher = stronger coupling), stronger links are
    treated as longer paths; a common alternative is a distance attribute such as
    1/weight. Left unchanged here to keep these features comparable with earlier
    runs - confirm which convention is intended.
    ``global_efficiency`` is computed as 1 / char_path_length, an approximation
    of (not identical to) the usual mean-inverse-distance definition.
    """
    all_features = []
    feature_names = [
        'num_nodes', 'num_edges', 'density',
        'total_weight', 'avg_weight', 'min_weight', 'max_weight', 'std_weight',
        'char_path_length', 'global_efficiency', 'diameter',
        'avg_clustering_weighted', 'transitivity',
        'avg_degree_centrality', 'avg_betweenness_centrality', 'avg_closeness_centrality',
        'assortativity_degree', 'avg_weighted_degree',
        'modularity', 'num_communities', 'max_community_size',
        'spectral_radius', 'algebraic_connectivity',
        'num_connected_components', 'size_largest_component_ratio',
        'gini_coeff_degree', 'entropy_degree',
        'avg_eccentricity', 'periphery_nodes_ratio', 'center_nodes_ratio',
        'avg_k_core_size', 'max_k_core_size',
        'node_connectivity', 'edge_connectivity',
        'avg_shortest_path_unweighted',
        'avg_degree_unweighted'
    ] # Total 36 static features

    for G in graphs:
        feat = {name: np.nan for name in feature_names} # Initialize with NaN for all features

        # Graphs with 0 or 1 node: only the size counts are meaningful, every
        # other feature stays NaN.
        if not G or G.number_of_nodes() == 0:
            feat['num_nodes'] = 0
            feat['num_edges'] = 0
            all_features.append(list(feat.values()))
            continue
        if G.number_of_nodes() == 1:
            feat['num_nodes'] = 1
            feat['num_edges'] = 0
            all_features.append(list(feat.values()))
            continue

        feat['num_nodes'] = G.number_of_nodes()
        feat['num_edges'] = G.number_of_edges()
        feat['density'] = nx.density(G)

        # Edge-weight statistics
        if G.number_of_edges() > 0:
            weights = [data['weight'] for u, v, data in G.edges(data=True)]
            feat['total_weight'] = sum(weights)
            feat['avg_weight'] = np.mean(weights)
            feat['min_weight'] = np.min(weights)
            feat['max_weight'] = np.max(weights)
            feat['std_weight'] = np.std(weights)
        else:
            feat['total_weight'] = 0
            feat['avg_weight'] = 0
            feat['min_weight'] = np.nan
            feat['max_weight'] = np.nan
            feat['std_weight'] = np.nan

        # Largest Connected Component (LCC) for path-based metrics
        is_conn = nx.is_connected(G)
        if not is_conn:
            components = list(nx.connected_components(G))
            G_lcc = G.subgraph(max(components, key=len))
            feat['num_connected_components'] = len(components)
            feat['size_largest_component_ratio'] = G_lcc.number_of_nodes() / G.number_of_nodes()
        else:
            G_lcc = G
            feat['num_connected_components'] = 1
            feat['size_largest_component_ratio'] = 1.0

        # Path-based features (on LCC, since distances are undefined across components)
        if G_lcc.number_of_nodes() > 1 and G_lcc.number_of_edges() > 0:
            try:
                feat['char_path_length'] = nx.average_shortest_path_length(G_lcc, weight='weight')
                feat['avg_shortest_path_unweighted'] = nx.average_shortest_path_length(G_lcc)
            except (nx.NetworkXError, nx.NetworkXPointlessConcept):
                feat['char_path_length'] = np.nan
                feat['avg_shortest_path_unweighted'] = np.nan

            if not np.isnan(feat['char_path_length']) and feat['char_path_length'] != 0:
                feat['global_efficiency'] = 1 / feat['char_path_length']
            else:
                feat['global_efficiency'] = np.nan

            try:
                feat['diameter'] = nx.diameter(G_lcc)
                eccentricities = nx.eccentricity(G_lcc)
                feat['avg_eccentricity'] = np.mean(list(eccentricities.values()))
                periphery_nodes = nx.periphery(G_lcc)
                center_nodes = nx.center(G_lcc)
                feat['periphery_nodes_ratio'] = len(periphery_nodes) / G_lcc.number_of_nodes()
                feat['center_nodes_ratio'] = len(center_nodes) / G_lcc.number_of_nodes()

            except (nx.NetworkXError, nx.NetworkXPointlessConcept):
                feat['diameter'] = np.nan
                feat['avg_eccentricity'] = np.nan
                feat['periphery_nodes_ratio'] = np.nan
                feat['center_nodes_ratio'] = np.nan
        else:
            feat['char_path_length'] = np.nan
            feat['global_efficiency'] = np.nan
            feat['diameter'] = np.nan
            feat['avg_eccentricity'] = np.nan
            feat['periphery_nodes_ratio'] = np.nan
            feat['center_nodes_ratio'] = np.nan
            feat['avg_shortest_path_unweighted'] = np.nan

        feat['avg_clustering_weighted'] = nx.average_clustering(G, weight='weight')
        feat['transitivity'] = nx.transitivity(G)

        # Removed 'avg_local_efficiency' as it consistently gave NaN

        # Centrality measures (averaged over nodes)
        if G.number_of_nodes() > 0:
            degree_centrality = nx.degree_centrality(G)
            feat['avg_degree_centrality'] = np.mean(list(degree_centrality.values()))

            if G.number_of_nodes() > 2:
                betweenness = nx.betweenness_centrality(G, weight='weight')
                feat['avg_betweenness_centrality'] = np.mean(list(betweenness.values()))

                closeness = nx.closeness_centrality(G, distance='weight')
                feat['avg_closeness_centrality'] = np.mean(list(closeness.values()))
            else:
                feat['avg_betweenness_centrality'] = np.nan
                feat['avg_closeness_centrality'] = np.nan
        else:
            feat.update({'avg_degree_centrality': np.nan, 'avg_betweenness_centrality': np.nan, 'avg_closeness_centrality': np.nan})

        # Degree-related features (weighted = node strength, and unweighted)
        degrees = dict(G.degree(weight='weight'))
        if G.number_of_nodes() > 0:
            feat['avg_weighted_degree'] = sum(degrees.values()) / G.number_of_nodes()
            feat['avg_degree_unweighted'] = np.mean(list(dict(G.degree()).values()))
        else:
            feat['avg_weighted_degree'] = 0
            feat['avg_degree_unweighted'] = 0

        # Assortativity (can be undefined, e.g. when all degrees are equal)
        if G.number_of_edges() > 0:
            try:
                feat['assortativity_degree'] = nx.degree_assortativity_coefficient(G, weight='weight')
            except Exception:
                feat['assortativity_degree'] = np.nan
        else:
            feat['assortativity_degree'] = np.nan

        # Community structure features (using Louvain method)
        try:
            if G.number_of_edges() > 0 and G.number_of_nodes() > 1:
                partition = community_louvain.best_partition(G, weight='weight')
                feat['modularity'] = community_louvain.modularity(partition, G, weight='weight')
                num_communities = len(set(partition.values()))
                feat['num_communities'] = num_communities
                if num_communities > 0:
                    community_sizes = [list(partition.values()).count(c) for c in set(partition.values())]
                    feat['max_community_size'] = np.max(community_sizes)
                else:
                    feat['max_community_size'] = np.nan
            else:
                feat['modularity'] = np.nan
                feat['num_communities'] = np.nan
                feat['max_community_size'] = np.nan
        except Exception:
            # Best-effort: a failed partition leaves the community features NaN.
            feat['modularity'] = np.nan
            feat['num_communities'] = np.nan
            feat['max_community_size'] = np.nan

        # Spectral features (eigenvalues of the normalized graph Laplacian)
        if G.number_of_nodes() > 1:
            try:
                L = nx.normalized_laplacian_matrix(G, weight='weight')
                eigenvalues = np.sort(np.linalg.eigvalsh(L.toarray()))  # ascending
                feat['spectral_radius'] = eigenvalues[-1]
                # Algebraic connectivity is the SECOND-smallest Laplacian eigenvalue.
                # A disconnected graph has one zero eigenvalue per component, so its
                # value is 0; skipping to the first "non-zero" eigenvalue would wrongly
                # report a positive value. Tiny negative round-off is clipped to 0.
                feat['algebraic_connectivity'] = max(float(eigenvalues[1]), 0.0)
            except Exception:
                feat['spectral_radius'] = np.nan
                feat['algebraic_connectivity'] = np.nan
        else:
            feat['spectral_radius'] = np.nan
            feat['algebraic_connectivity'] = np.nan

        # Distributional features of the weighted-degree (strength) distribution
        if G.number_of_nodes() > 1:
            degrees_values = np.array(list(dict(G.degree(weight='weight')).values()))
            if len(degrees_values) > 1:
                # Gini coefficient: sum((2i - n - 1) * x_i) / (n^2 * mean(x)), x sorted ascending.
                sorted_degrees = np.sort(degrees_values)
                n = len(sorted_degrees)
                numerator = np.sum([(2 * (i + 1) - n - 1) * sorted_degrees[i] for i in range(n)])
                denominator = n**2 * np.mean(sorted_degrees)
                feat['gini_coeff_degree'] = numerator / denominator if denominator != 0 else np.nan

                # Shannon entropy (bits) of the binned distribution. Bin counts are
                # normalised to probabilities; density=True would give per-unit-width
                # densities, which can exceed 1 and make the "entropy" negative and
                # dependent on the bin width.
                bin_counts, _ = np.histogram(degrees_values, bins='auto')
                probabilities = bin_counts[bin_counts > 0] / bin_counts.sum()
                feat['entropy_degree'] = -np.sum(probabilities * np.log2(probabilities))
            else:
                feat['gini_coeff_degree'] = np.nan
                feat['entropy_degree'] = np.nan
        else:
            feat['gini_coeff_degree'] = np.nan
            feat['entropy_degree'] = np.nan

        # Core-Periphery Measures (k-core number per node)
        if G.number_of_nodes() > 1:
            try:
                k_core = nx.core_number(G)
                if k_core:
                    feat['avg_k_core_size'] = np.mean(list(k_core.values()))
                    feat['max_k_core_size'] = np.max(list(k_core.values()))
                else:
                    feat['avg_k_core_size'] = np.nan
                    feat['max_k_core_size'] = np.nan
            except Exception:
                feat['avg_k_core_size'] = np.nan
                feat['max_k_core_size'] = np.nan
        else:
            feat['avg_k_core_size'] = np.nan
            feat['max_k_core_size'] = np.nan

        # Removed 'rich_club_coeff', 'small_world_omega', 'small_world_sigma' calculation blocks

        # Robustness Measures (only defined here for connected graphs)
        if G.number_of_nodes() > 1 and nx.is_connected(G):
            try:
                feat['node_connectivity'] = nx.node_connectivity(G)
            except nx.NetworkXNoPath:
                feat['node_connectivity'] = 0
            except nx.NetworkXError:
                feat['node_connectivity'] = np.nan
            try:
                feat['edge_connectivity'] = nx.edge_connectivity(G)
            except nx.NetworkXNoPath:
                feat['edge_connectivity'] = 0
            except nx.NetworkXError:
                feat['edge_connectivity'] = np.nan
        else:
            feat['node_connectivity'] = np.nan
            feat['edge_connectivity'] = np.nan

        # Ensure the order of features appended to all_features matches feature_names
        ordered_features = [feat[name] for name in feature_names]
        all_features.append(ordered_features)

    # reshape keeps a 2D (0, n_features) result when `graphs` is empty.
    return np.array(all_features, dtype=float).reshape(-1, len(feature_names)), feature_names

# --- Updated add_dynamic_features function ---
def _edge_key_set(graph):
    """Return the edge set of ``graph`` in a form comparable across graphs.

    Undirected edges are reported as ``(u, v)`` or ``(v, u)`` depending on node
    insertion order, so they are normalised to unordered ``frozenset`` pairs.
    Directed edges keep their orientation.
    """
    if graph.is_directed():
        return set(graph.edges())
    return {frozenset(edge) for edge in graph.edges()}


def _adjacency_over(graph, nodelist):
    """Weighted adjacency matrix of ``graph`` with rows/columns ordered by ``nodelist``.

    ``nx.to_numpy_array`` raises if ``nodelist`` contains a node that is not in
    the graph, so any such nodes are first added as isolated nodes on a copy.
    """
    missing = [node for node in nodelist if node not in graph]
    if missing:
        graph = graph.copy()
        graph.add_nodes_from(missing)
    return nx.to_numpy_array(graph, nodelist=nodelist, weight='weight')


def _mean_node_strength(graph):
    """Mean weighted degree of ``graph``; NaN for a graph without nodes."""
    if graph.number_of_nodes() == 0:
        return np.nan
    return np.mean([strength for _, strength in graph.degree(weight='weight')])


def _mean_weighted_clustering(graph):
    """Weighted average clustering of ``graph``; NaN for a graph without nodes."""
    if graph.number_of_nodes() == 0:
        return np.nan
    return nx.average_clustering(graph, weight='weight')


def add_dynamic_features(graphs, strong_threshold=0.5):
    """
    Compute per-epoch edge-weight statistics and epoch-to-epoch change features.

    ``graphs`` is a list of NetworkX graphs ordered by time/epoch. Exactly one
    feature row is produced per graph, so the result aligns row-for-row with the
    static feature matrix and the epoch labels built from it.

    Parameters
    ----------
    graphs : list of networkx graphs
        Every edge must carry a numeric ``'weight'`` attribute (it is read
        directly and a missing attribute raises ``KeyError``).
    strong_threshold : float, default 0.5
        Edges whose weight is strictly greater than this count as "strong".

    Returns
    -------
    features : ndarray, shape (len(graphs), 7)
        Columns follow ``dynamic_feature_names``. The four change features
        (columns 3-6) are NaN for the first epoch, which has no predecessor.
        An empty ``graphs`` list yields an array of shape (0, 7).
    dynamic_feature_names : list of str
    """
    dynamic_feature_names = [
        'current_epoch_edge_weight_variability',
        'current_epoch_mean_edge_weight',
        'current_epoch_strong_connections_count',
        'reconfiguration_jaccard_index',
        'frobenius_norm_diff',
        'delta_mean_degree',
        'delta_mean_clustering'
    ]  # Total 7 dynamic features

    if not graphs:
        # No epochs -> no rows; emitting a NaN row here would create a fake sample.
        return np.empty((0, len(dynamic_feature_names))), dynamic_feature_names

    per_epoch_dynamic_features = []
    for i, G in enumerate(graphs):
        # --- Features of the current epoch alone ---
        weights = [G[u][v]['weight'] for u, v in G.edges()]
        if weights:
            ew_variability = np.std(weights) if len(weights) > 1 else 0
            mean_ew = np.mean(weights)
            strong_conn_count = sum(1 for w in weights if w > strong_threshold)
        else:
            ew_variability = np.nan
            mean_ew = np.nan
            strong_conn_count = np.nan

        # --- Change relative to the previous epoch (undefined for the first) ---
        jaccard_index = np.nan
        frobenius_diff = np.nan
        delta_mean_degree = np.nan
        delta_mean_clustering = np.nan

        if i > 0:
            G_prev = graphs[i - 1]

            # Reconfiguration: 1 - Jaccard similarity of the unweighted edge sets
            # (i.e. a dissimilarity; 0 means the edge set did not change).
            edges_current = _edge_key_set(G)
            edges_prev = _edge_key_set(G_prev)
            union_size = len(edges_current | edges_prev)
            if union_size > 0:
                jaccard_index = 1 - len(edges_current & edges_prev) / union_size
            else:
                jaccard_index = 0  # both graphs edgeless: nothing reconfigured

            # Frobenius norm of the weighted-adjacency difference over the union of
            # both node sets, so a node missing from one epoch cannot crash this.
            nodes = sorted(set(G.nodes()) | set(G_prev.nodes()))
            if nodes:
                frobenius_diff = np.linalg.norm(
                    _adjacency_over(G, nodes) - _adjacency_over(G_prev, nodes), 'fro')

            # NaN propagates through the subtraction when either graph is empty.
            delta_mean_degree = _mean_node_strength(G) - _mean_node_strength(G_prev)
            delta_mean_clustering = _mean_weighted_clustering(G) - _mean_weighted_clustering(G_prev)

        per_epoch_dynamic_features.append([
            ew_variability,
            mean_ew,
            strong_conn_count,
            jaccard_index,
            frobenius_diff,
            delta_mean_degree,
            delta_mean_clustering
        ])

    return np.array(per_epoch_dynamic_features), dynamic_feature_names



In [None]:
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests
import pandas as pd
import seaborn as sns



# Per-connectivity-type results (per-subject feature matrices, t-tests,
# Bonferroni correction, Cohen's d), keyed by connectivity type.
analysis_results = {}


def _mean_impute_keep_columns(X):
    """
    Replace NaNs in each column by that column's mean, keeping every column.

    ``SimpleImputer(strategy='mean')`` with default settings silently DROPS
    columns that are entirely NaN, which desynchronises the matrix from
    ``all_feature_names`` and gives subjects different column counts. Here an
    all-NaN column is filled with 0.0 instead, so the width never changes.
    """
    X = np.asarray(X, dtype=float)
    nan_mask = np.isnan(X)
    n_observed = (~nan_mask).sum(axis=0)
    column_sums = np.where(nan_mask, 0.0, X).sum(axis=0)
    column_means = np.divide(column_sums, n_observed,
                             out=np.zeros(X.shape[1]), where=n_observed > 0)
    return np.where(nan_mask, column_means, X)


def compute_cohens_d(x1, x2):
    """
    Cohen's d of ``x1`` versus ``x2`` using the pooled (ddof=1) standard deviation.

    Returns 0.0 when the pooled standard deviation is exactly zero.
    """
    n1, n2 = len(x1), len(x2)
    var1 = np.var(x1, ddof=1)
    var2 = np.var(x2, ddof=1)
    # The tiny epsilon keeps the denominator non-zero when n1 + n2 == 2.
    pooled_std = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 2 + 1e-9))
    if pooled_std == 0:
        return 0.0
    return (np.mean(x1) - np.mean(x2)) / pooled_std


def _empty_stat_results():
    """Placeholder results for a connectivity type that has nothing to test."""
    return {
        't_values': np.array([]), 'p_values': np.array([]),
        'corrected_pvals': np.array([]), 'significant_mask': np.array([]),
        'cohens_d_values': []
    }


# Iterate over each connectivity type. Graphs are looked up under the keys
# f'{conn_type}_att' / f'{conn_type}_inatt' in graphs_by_subject.
for conn_type in ['coh', 'mi', 'dpli', 'pearson_r']:
    print(f"\n--- Starting analysis for {conn_type.upper()} connectivity ---")

    features_by_subject = {}
    labels_by_subject = {}
    combined_features_by_subject = {}
    combined_labels_by_subject = {}
    all_feature_names = []  # Filled from the first subject processed for this conn_type

    # Raises NameError if the graph-conversion cell was not run; skips this
    # connectivity type if it ran but produced nothing.
    if not graphs_by_subject:
        print("Error: 'graphs_by_subject' is empty or not defined. Please ensure graph conversion ran successfully.")
        continue

    for subject_id in graphs_by_subject.keys():
        graphs_att = graphs_by_subject[subject_id].get(f'{conn_type}_att')
        graphs_inatt = graphs_by_subject[subject_id].get(f'{conn_type}_inatt')

        if graphs_att is None or graphs_inatt is None:
            print(f"Warning: {conn_type.upper()} data not found for subject {subject_id}. Skipping this subject for {conn_type}.")
            continue

        # Static (per-graph) and dynamic (epoch-to-epoch) features.
        X_att_static, static_names = extract_graph_features(graphs_att)
        X_inatt_static, _ = extract_graph_features(graphs_inatt)
        X_att_dyn, dynamic_names = add_dynamic_features(graphs_att)
        X_inatt_dyn, _ = add_dynamic_features(graphs_inatt)

        if not all_feature_names:
            # Prefix with the connectivity type so names stay unique once
            # modalities are combined later.
            all_feature_names = ([f"{conn_type}_{name}" for name in static_names]
                                 + [f"{conn_type}_{name}" for name in dynamic_names])

        # Keep only the feature groups that actually have columns.
        current_X_att = []
        current_X_inatt = []
        if X_att_static.shape[1] > 0:
            current_X_att.append(X_att_static)
            current_X_inatt.append(X_inatt_static)
        if X_att_dyn.shape[1] > 0:
            current_X_att.append(X_att_dyn)
            current_X_inatt.append(X_inatt_dyn)

        if not current_X_att or not current_X_inatt:
            print(f"Warning: No features (static or dynamic) extracted for {subject_id} for {conn_type}. Skipping subject.")
            continue

        X_att = np.hstack(current_X_att)
        X_inatt = np.hstack(current_X_inatt)
        n_att = X_att.shape[0]

        # Mean-impute then z-score using statistics pooled over both conditions
        # of this subject only (within-subject normalisation). The imputer keeps
        # all-NaN columns so the width always matches all_feature_names.
        X_combined_imputed = _mean_impute_keep_columns(np.vstack((X_att, X_inatt)))
        X_combined_scaled = StandardScaler().fit_transform(X_combined_imputed)
        X_att_scaled = X_combined_scaled[:n_att]
        X_inatt_scaled = X_combined_scaled[n_att:]

        features_by_subject[subject_id] = {
            'attention': X_att_scaled,
            'inattention': X_inatt_scaled
        }

        # Label convention: 1 = attention, 0 = inattention.
        labels_att = np.ones(X_att_scaled.shape[0])
        labels_inatt = np.zeros(X_inatt_scaled.shape[0])

        labels_by_subject[subject_id] = {
            'attention': labels_att,
            'inattention': labels_inatt
        }

        # Rows are stacked attention-first, then inattention.
        combined_features_by_subject[subject_id] = np.vstack((X_att_scaled, X_inatt_scaled))
        combined_labels_by_subject[subject_id] = np.concatenate((labels_att, labels_inatt))

    print(f"Extracted features for {len(features_by_subject)} subjects for {conn_type.upper()}.")
    print(f"Total number of features extracted for {conn_type.upper()}: {len(all_feature_names)}")

    # === Combine all subjects' data for the t-test ===
    if not features_by_subject:
        print(f"No features available for {conn_type.upper()}. Skipping statistical analysis for this type.")
        analysis_results[conn_type] = _empty_stat_results()  # avoids key errors later
        continue

    all_feats = []
    all_labels = []
    for subject_id in combined_features_by_subject:
        if combined_features_by_subject[subject_id].shape[0] > 0:
            all_feats.append(combined_features_by_subject[subject_id])
            all_labels.append(combined_labels_by_subject[subject_id])

    if not all_feats:
        print(f"No combined features from any subject for {conn_type.upper()}. Skipping statistical analysis.")
        analysis_results[conn_type] = _empty_stat_results()
        continue

    X_all = np.vstack(all_feats)
    y_all = np.concatenate(all_labels)

    print(f"Combined data shape for {conn_type.upper()}: {X_all.shape}, Labels shape: {y_all.shape}")

    X_att = np.nan_to_num(X_all[y_all == 1], nan=0.0)   # residual-NaN safety net
    X_inatt = np.nan_to_num(X_all[y_all == 0], nan=0.0)

    # === Welch t-test per feature + Bonferroni correction ===
    if X_att.shape[0] < 2 or X_inatt.shape[0] < 2 or X_all.shape[1] == 0:
        print(f"Not enough samples ({X_att.shape[0]} att, {X_inatt.shape[0]} inatt) or no features ({X_all.shape[1]}) for t-test for {conn_type.upper()}. Skipping.")
        t_values = np.array([])
        p_values = np.array([])
        significant_mask = np.array([])
        corrected_pvals = np.array([])
    else:
        t_values, p_values = ttest_ind(X_att, X_inatt, axis=0, equal_var=False)
        if len(p_values) == 0:
            significant_mask = np.array([])
            corrected_pvals = np.array([])
        else:
            significant_mask, corrected_pvals, _, _ = multipletests(p_values, alpha=0.05, method='bonferroni')

    n_significant = np.sum(significant_mask)
    print(f"\nNumber of significant features for {conn_type.upper()}: {n_significant} / {len(p_values)}")

    # === Show top features (significant ones first, else lowest p overall) ===
    top_k = 10
    top_indices = []
    if len(p_values) > 0:
        sorted_p_indices = np.argsort(p_values)
        significant_indices_sorted = [idx for idx in sorted_p_indices if significant_mask[idx]]
        if significant_indices_sorted:
            top_indices = significant_indices_sorted[:top_k]
        elif all_feature_names:
            # Convert to a list: `if ndarray:` raises ValueError for >1 element.
            top_indices = list(sorted_p_indices[:top_k])

    if len(top_indices) > 0:
        print(f"\nTop {top_k} features for {conn_type.upper()} (by p-value):")
        for idx in top_indices:
            print(f"   {idx:3d} | p={p_values[idx]:.4e} | Corrected p={corrected_pvals[idx]:.4e} | Significant: {significant_mask[idx]} | {all_feature_names[idx]}")
    else:
        print(f"No features to display for {conn_type.upper()}.")

    # === Plot p-values with uncorrected and Bonferroni thresholds ===
    if len(p_values) > 0:
        plt.figure(figsize=(12, 4))
        plt.plot(p_values, marker='o', linestyle='-', alpha=0.5, label='p-values')
        if n_significant > 0:
            plt.plot(np.where(significant_mask)[0], p_values[significant_mask], 'ro', label='Significant (Bonferroni)')
        plt.axhline(y=0.05, color='red', linestyle='--', label='p=0.05 (uncorrected)')
        plt.axhline(y=0.05 / len(p_values), color='green', linestyle=':', label='Bonferroni threshold')
        plt.title(f"p-values for all features ({conn_type.upper()}: Attention vs Inattention)")
        plt.xlabel("Feature Index")
        plt.ylabel("p-value")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()
    else:
        print(f"No p-values to plot for {conn_type.upper()}.")

    # === Effect sizes ===
    cohens_d_values = []
    if X_all.shape[1] > 0 and X_att.shape[0] > 0 and X_inatt.shape[0] > 0:
        cohens_d_values = [compute_cohens_d(X_att[:, i], X_inatt[:, i]) for i in range(X_all.shape[1])]

    analysis_results[conn_type] = {
        'features_by_subject': features_by_subject,
        'labels_by_subject': labels_by_subject,
        'combined_features_by_subject': combined_features_by_subject,
        'combined_labels_by_subject': combined_labels_by_subject,
        'all_feature_names': all_feature_names,
        't_values': t_values,
        'p_values': p_values,
        'corrected_pvals': corrected_pvals,
        'significant_mask': significant_mask,
        'cohens_d_values': cohens_d_values
    }

    # === Example boxplot: first significant feature, else feature 0 ===
    if len(all_feature_names) > 0 and X_att.shape[0] > 0 and X_inatt.shape[0] > 0:
        significant_indices = np.where(significant_mask)[0]
        example_feat_idx = significant_indices[0] if len(significant_indices) > 0 else 0

        if example_feat_idx < X_att.shape[1]:
            df = pd.DataFrame({
                'Feature Value': np.concatenate([X_att[:, example_feat_idx], X_inatt[:, example_feat_idx]]),
                'Label': ['Attention'] * len(X_att) + ['Inattention'] * len(X_inatt)
            })

            plt.figure(figsize=(6, 4))
            sns.boxplot(data=df, x='Label', y='Feature Value')
            plt.title(f'Distribution of {all_feature_names[example_feat_idx]} for {conn_type.upper()}')
            plt.grid(True)
            plt.show()
        else:
            print(f"Not enough data or valid feature index to plot example boxplot for {conn_type.upper()}.")
    else:
        print(f"No features or data to plot example boxplot for {conn_type.upper()}.")

print("\n--- Finished analysis for all connectivity types ---")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier,
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier # Needed for AdaBoost base_estimator

from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score, roc_curve, auc
from sklearn.model_selection import RandomizedSearchCV
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from scipy.stats import ttest_ind
import warnings
import pandas as pd
from collections import defaultdict

# Silence every warning process-wide for the rest of the kernel session
# (a bare 'ignore' filter matches all categories and all modules).
# NOTE(review): this also hides genuine problems such as solver
# non-convergence; consider narrowing it with category=... or a
# `warnings.catch_warnings()` context around the noisy call only.
warnings.filterwarnings('ignore')




def ttest_feature_selection(X_train, y_train, X_test, alpha=0.05, fallback_k=10):
    """
    Select features whose Welch t-test p-value (class 1 vs class 0) is below ``alpha``.

    The test uses the training data only; the same columns are then taken from
    ``X_test`` so no test information leaks into the selection.

    Parameters
    ----------
    X_train : ndarray, shape (n_train, n_features)
    y_train : ndarray, shape (n_train,)
        Binary labels (1 vs 0).
    X_test : ndarray, shape (n_test, n_features)
    alpha : float, default 0.05
        Uncorrected p-value threshold.
    fallback_k : int, default 10
        Number of lowest-p features kept when none pass ``alpha``.

    Returns
    -------
    X_train_sel, X_test_sel, selected_indices
        ``selected_indices`` is always an integer array (possibly empty), so it
        can index columns or feature-name arrays directly.
        If either class has fewer than 2 samples, all features are returned.
    """
    n_features = X_train.shape[1]

    def _no_features():
        # Empty but correctly shaped outputs; integer dtype keeps the index usable.
        return (np.empty((X_train.shape[0], 0)),
                np.empty((X_test.shape[0], 0)),
                np.array([], dtype=int))

    X_att = X_train[y_train == 1]
    X_inatt = X_train[y_train == 0]

    if X_att.shape[0] < 2 or X_inatt.shape[0] < 2:
        print("Warning: Not enough samples in one or both classes for t-test. Skipping t-test based selection, returning all features.")
        if n_features > 0:
            return X_train, X_test, np.arange(n_features)
        return _no_features()

    if n_features == 0:
        print("Warning: No features in input to t-test selection. Returning empty arrays.")
        return _no_features()

    _, p_vals = ttest_ind(X_att, X_inatt, axis=0, equal_var=False)
    selected_indices = np.flatnonzero(p_vals < alpha)

    if selected_indices.size == 0:
        print(f"Warning: No features passed t-test threshold (alpha={alpha}). Selecting top {fallback_k} features by p-value.")
        # argsort puts NaN p-values (e.g. constant features) last.
        selected_indices = np.argsort(p_vals)[:min(fallback_k, p_vals.size)]
        if selected_indices.size == 0:
            # e.g. fallback_k <= 0: keep everything rather than nothing.
            selected_indices = np.arange(n_features)

    return X_train[:, selected_indices], X_test[:, selected_indices], selected_indices

def zscore_per_subject(X):
    """
    Column-wise z-score of a 2-D feature matrix (population std, ddof=0).

    Columns with zero spread are divided by 1 instead of 0, so they become
    all zeros. An input with no rows or no columns is returned unchanged
    (the same object). Per the original note, the main loop uses
    StandardScaler rather than this helper.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    if not (n_rows and n_cols):
        return X

    centre = np.mean(X, axis=0, keepdims=True)
    spread = np.std(X, axis=0, keepdims=True)
    spread = np.where(spread == 0, 1, spread)  # guard against division by zero
    return (X - centre) / spread

# --- Candidate models and their (deliberately small) hyperparameter spaces ---
# Each entry maps a display name to (unfitted estimator, parameter space).
# Presumably consumed by RandomizedSearchCV (imported above) in the LOSO loop —
# confirm there. Insertion order below is the dictionary's iteration order.
models_and_params = {}

models_and_params['LogisticRegression'] = (
    LogisticRegression(max_iter=2000, random_state=42),
    {
        'C': [0.001, 0.01, 0.1, 1],
        'penalty': ['l1', 'l2', 'elasticnet'],
        # 'saga' is the one solver that accepts l1, l2 and elasticnet;
        # l1_ratio is only meaningful when penalty='elasticnet'.
        'solver': ['saga'],
        'l1_ratio': [0.1, 0.5, 0.9],
        # NOTE(review): overrides the constructor's max_iter=2000 for every draw.
        'max_iter': [1000],
    },
)

models_and_params["SGD Classifier"] = (
    SGDClassifier(random_state=42),
    {
        'loss': ['hinge', 'log_loss', 'modified_huber'],  # 'log' is deprecated
        'penalty': ['l1', 'l2', 'elasticnet'],
        'alpha': [0.0001, 0.001, 0.01],
        'l1_ratio': [0.3, 0.7],
        'learning_rate': ['constant', 'adaptive'],
        'eta0': [0.01, 0.1],
    },
)

models_and_params['MLP'] = (
    MLPClassifier(max_iter=2000, random_state=42),
    {
        'hidden_layer_sizes': [(100,), (150,), (100, 50)],  # 1 or 2 hidden layers
        'activation': ['relu', 'tanh'],
        'solver': ['adam'],
        'alpha': [0.001, 0.01],
        'learning_rate': ['constant', 'adaptive'],
        'learning_rate_init': [0.0001, 0.001],
    },
)

models_and_params['RandomForest'] = (
    RandomForestClassifier(random_state=42),
    {
        'n_estimators': [100, 200],
        'max_depth': [None, 10, 20],
        'min_samples_split': [2, 10],
        'min_samples_leaf': [1, 4],
        'max_features': ['sqrt', None],
        'bootstrap': [True],
        'criterion': ['gini'],
    },
)

models_and_params["Extra Trees"] = (
    ExtraTreesClassifier(random_state=42),
    {
        'n_estimators': [100, 200],
        'max_depth': [None, 20],
        'min_samples_split': [2, 10],
        'min_samples_leaf': [1, 4],
        'max_features': ['sqrt', None],
        'criterion': ['gini'],
    },
)

models_and_params["Gradient Boosting"] = (
    GradientBoostingClassifier(random_state=42),
    {
        'n_estimators': [100],
        'learning_rate': [0.05, 0.1],
        'max_depth': [3, 5],
        'subsample': [0.8, 1.0],
        'min_samples_split': [2],
        'min_samples_leaf': [1],
        'max_features': ['sqrt'],
    },
)

# CPU 'hist' tree method (device='cuda' was removed for broader compatibility).
models_and_params["XGBoost"] = (
    XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42, tree_method='hist'),
    {
        'n_estimators': [100, 200],
        'max_depth': [3, 6],
        'learning_rate': [0.05, 0.1],
        'subsample': [0.8, 1.0],
        'colsample_bytree': [0.8, 1.0],
        'gamma': [0, 0.2],
        'reg_alpha': [0, 0.1],
        'reg_lambda': [0.1, 1],
    },
)

# probability=True enables predict_proba on the SVC.
models_and_params['SVM'] = (
    SVC(probability=True, random_state=42),
    {
        'C': [0.1, 1],
        'kernel': ['rbf'],
        'gamma': ['scale', 0.1],
        'shrinking': [True],
        'class_weight': [None],
    },
)

models_and_params["Gaussian Naive Bayes"] = (
    GaussianNB(),
    {'var_smoothing': [1e-9, 1e-7]},
)

models_and_params["Bernoulli Naive Bayes"] = (
    BernoulliNB(),
    {
        'alpha': [0.1, 1.0],
        'binarize': [0.0],  # features > 0 map to 1
        'fit_prior': [True, False],
    },
)


print("\n--- Combining COH, DPLI, pearson_r and MI features and names for each subject ---")
combined_all_features_by_subject = {}
combined_all_labels_by_subject = {}

# Modality order fixes the column order of the combined feature matrix.
COMBINED_MODALITIES = ['coh', 'dpli', 'mi', 'pearson_r']

all_subject_ids_sets = []
for modality in COMBINED_MODALITIES:
    if modality in analysis_results and 'combined_features_by_subject' in analysis_results[modality]:
        all_subject_ids_sets.append(set(analysis_results[modality]['combined_features_by_subject'].keys()))

if not all_subject_ids_sets:
    print("Error: No connectivity results found in 'analysis_results'. Please ensure previous steps ran successfully.")

# set().union(...) is safe for an empty list, unlike set.union(*[]) (TypeError).
subject_ids = sorted(set().union(*all_subject_ids_sets))

# Per-modality feature names (prefixed with the modality by the analysis cell).
coh_feature_names = analysis_results.get('coh', {}).get('all_feature_names', [])
dpli_feature_names = analysis_results.get('dpli', {}).get('all_feature_names', [])
mi_feature_names = analysis_results.get('mi', {}).get('all_feature_names', [])
pearson_r_feature_names = analysis_results.get('pearson_r', {}).get('all_feature_names', [])
feature_names_by_modality = {
    'coh': coh_feature_names,
    'dpli': dpli_feature_names,
    'mi': mi_feature_names,
    'pearson_r': pearson_r_feature_names,
}

combined_all_feature_names = coh_feature_names + dpli_feature_names + mi_feature_names + pearson_r_feature_names

# Largest epoch count seen for any subject/modality. Informational only: rows are
# NOT padded to this value, because padding appended all-zero feature rows
# labelled 0, i.e. fabricated "inattention" samples for shorter subjects.
max_epochs = 0
for sub_id in subject_ids:
    for modality in COMBINED_MODALITIES:
        feats = analysis_results.get(modality, {}).get('combined_features_by_subject', {}).get(sub_id)
        if feats is not None:
            max_epochs = max(max_epochs, feats.shape[0])

if max_epochs == 0:
    print("Warning: No epochs found across any subject or modality. Classification might fail.")


def _align_feature_columns(features, n_expected, modality, subject_id):
    """Truncate or zero-pad the columns of ``features`` to ``n_expected``, warning on mismatch."""
    if features.shape[1] == n_expected:
        return features
    print(f"  Warning: Feature count mismatch for {modality} of {subject_id}. Expected {n_expected}, got {features.shape[1]}. Skipping or padding might occur.")
    if features.shape[1] > n_expected:
        return features[:, :n_expected]
    padded = np.zeros((features.shape[0], n_expected))
    padded[:, :features.shape[1]] = features
    return padded


for subject_id in subject_ids:
    print(f"Processing subject {subject_id} for combined features...")

    # Collect the modalities that actually have features AND labels for this subject.
    available = {}
    for modality in COMBINED_MODALITIES:
        modality_results = analysis_results.get(modality, {})
        features_for_modality = modality_results.get('combined_features_by_subject', {}).get(subject_id)
        labels_for_modality = modality_results.get('combined_labels_by_subject', {}).get(subject_id)

        if features_for_modality is not None and features_for_modality.shape[0] > 0 and labels_for_modality is not None:
            features_for_modality = _align_feature_columns(
                features_for_modality, len(feature_names_by_modality[modality]), modality, subject_id)
            available[modality] = (features_for_modality, labels_for_modality)
        else:
            print(f"  Warning: {modality.upper()} data missing or empty for subject {subject_id}. Appending zeros.")

    if not available:
        print(f"Skipping subject {subject_id}: No valid features to combine across modalities.")
        continue

    # Keep only epochs present in every available modality; never invent rows.
    n_epochs = min(min(feats.shape[0], labels.shape[0]) for feats, labels in available.values())
    if any(feats.shape[0] != n_epochs for feats, _ in available.values()):
        # NOTE: rows are attention-first then inattention, so truncation across
        # modalities with different counts may misalign conditions — inspect.
        print(f"  Warning: Epoch counts differ between modalities for {subject_id}; truncating to {n_epochs}.")

    # Labels come from the first available modality, never from a missing one.
    subject_labels = next(iter(available.values()))[1][:n_epochs]
    if any(not np.array_equal(subject_labels, labels[:n_epochs]) for _, labels in available.values()):
        print(f"  Warning: Labels for subject {subject_id} differ between modalities. Using the first available labels.")

    # Missing modalities contribute zero columns so every subject has the same width.
    subject_features_to_stack = [
        available[modality][0][:n_epochs] if modality in available
        else np.zeros((n_epochs, len(feature_names_by_modality[modality])))
        for modality in COMBINED_MODALITIES
    ]

    if sum(f.shape[1] for f in subject_features_to_stack) == 0:
        print(f"Skipping subject {subject_id}: No valid features to combine across modalities.")
        continue

    combined_features = np.hstack(subject_features_to_stack)
    combined_all_features_by_subject[subject_id] = combined_features
    combined_all_labels_by_subject[subject_id] = subject_labels
    print(f"  Subject {subject_id}: Combined features shape {combined_features.shape}")

# Leave-one-subject-out (LOSO) classification on the combined
# COH + DPLI + MI + pearson_r features: each subject is held out once as the
# test set while all remaining subjects form the training set.
#
# NOTE(review): this cell uses `models_and_params`, `ttest_feature_selection`,
# `defaultdict`, `SimpleImputer`, `StandardScaler`, `VarianceThreshold`,
# `RandomizedSearchCV` and the sklearn metric functions, which are only
# defined/imported in a LATER cell of this notebook. It therefore fails on
# "Restart & Run All" — move it below that cell.
from sklearn.base import clone  # unfitted copy for models without a search grid

print(f"\n==================================================")
print(f"🔄 INTER-SUBJECT CLASSIFICATION (Combined COH+DPLI+MI+pearson_r Features - LOSO)")
print(f"==================================================")

# One entry per held-out subject (NaN when that fold/model was skipped), so the
# three lists stay index-aligned for every model.
loso_accuracies_summary = {name: [] for name in models_and_params}
loso_f1_summary = {name: [] for name in models_and_params}
loso_roc_auc_summary = {name: [] for name in models_and_params}

# Per-model ROC curve points of the folds where ROC AUC could be computed,
# consumed by the average-ROC plot.
all_fprs = defaultdict(list)
all_tprs = defaultdict(list)


def _record_fold(name, acc=np.nan, f1=np.nan, roc_auc=np.nan):
    """Append exactly one accuracy / F1 / ROC AUC entry for model `name`."""
    loso_accuracies_summary[name].append(acc)
    loso_f1_summary[name].append(f1)
    loso_roc_auc_summary[name].append(roc_auc)


def _record_skipped_fold():
    """Record NaN metrics for every model when the whole fold is skipped."""
    for model_name in models_and_params:
        _record_fold(model_name)


def _search_space_size(param_grid):
    """Number of combinations in a dict-of-lists grid (inf if an entry has no length)."""
    size = 1
    for values in param_grid.values():
        if not hasattr(values, '__len__'):
            return float('inf')
        size *= len(values)
    return size


actual_subjects_with_combined_data = list(combined_all_features_by_subject.keys())

if not actual_subjects_with_combined_data:
    print("No subjects with valid combined data for LOSO. Exiting classification.")
    # With no subjects the loop below does nothing and the summaries stay empty.

for test_subject in actual_subjects_with_combined_data:
    X_test_raw = combined_all_features_by_subject[test_subject]
    y_test = combined_all_labels_by_subject[test_subject]

    train_subjects = [s for s in actual_subjects_with_combined_data if s != test_subject]
    X_train_list = [combined_all_features_by_subject[s] for s in train_subjects]
    y_train_list = [combined_all_labels_by_subject[s] for s in train_subjects]

    if not X_train_list:
        print(f"Skipping classification for {test_subject}: No training data available.")
        _record_skipped_fold()
        continue

    X_train_raw = np.vstack(X_train_list)
    y_train = np.hstack(y_train_list)

    print(f"\n--- Testing on Subject {test_subject} (Combined COH+DPLI+MI+pearson_r Features) ---")
    print(f"  Train samples: {X_train_raw.shape[0]}, Test samples: {X_test_raw.shape[0]}")
    print(f"  Train features: {X_train_raw.shape[1]}, Test features: {X_test_raw.shape[1]}")

    if X_train_raw.shape[0] == 0 or X_test_raw.shape[0] == 0:
        print(f"Skipping classification for {test_subject}: Empty train or test set.")
        _record_skipped_fold()
        continue

    if X_train_raw.shape[1] == 0:
        print(f"Skipping classification for {test_subject}: No features available after initial combining (all zeros/empty).")
        _record_skipped_fold()
        continue

    # --- Preprocessing: every step is fit on the training subjects only ---
    # NOTE(review): SimpleImputer drops columns that are entirely NaN in the
    # training data, which would shift the indices used below to look up
    # `combined_all_feature_names` — confirm no feature column is all-NaN.
    imputer = SimpleImputer(strategy='mean')
    X_train_imputed = imputer.fit_transform(X_train_raw)
    X_test_imputed = imputer.transform(X_test_raw)

    # Z-score normalization
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_imputed)
    X_test_scaled = scaler.transform(X_test_imputed)

    # Drop (near-)constant features
    var_thresh = VarianceThreshold(threshold=1e-5)
    X_train_f = var_thresh.fit_transform(X_train_scaled)
    X_test_f = var_thresh.transform(X_test_scaled)
    original_indices_after_var_thresh = var_thresh.get_support(indices=True)

    if X_train_f.shape[1] == 0:
        print(f"  No features remaining after variance threshold for {test_subject}. Skipping models.")
        _record_skipped_fold()
        continue

    # The t-test selection depends only on the training fold, not on the model,
    # so it is computed once per fold and shared by all models.
    X_train_sel, X_test_sel, selected_relative_indices = ttest_feature_selection(X_train_f, y_train, X_test_f)

    # Map the selected columns back to the combined feature names for printing.
    selected_feature_names = np.array([])
    if len(selected_relative_indices) > 0 and len(original_indices_after_var_thresh) > 0:
        selected_global_indices = original_indices_after_var_thresh[selected_relative_indices]
        selected_feature_names = np.array(combined_all_feature_names)[selected_global_indices]

    print(f"  Selected {len(selected_feature_names)} features out of {X_train_raw.shape[1]} original combined features.")
    if len(selected_feature_names) > 0:
        print(f"  Selected Features: {', '.join(selected_feature_names[:10])}{'...' if len(selected_feature_names) > 10 else ''}")  # first 10 only
    else:
        print("  No features selected.")

    if X_train_sel.shape[1] == 0:
        print(f"  No features remaining after t-test selection. Skipping classification for {test_subject}.")
        _record_skipped_fold()
        continue

    unique_classes_train = np.unique(y_train)
    if X_train_sel.shape[0] < 3 or len(unique_classes_train) < 2:
        print(f"  Not enough samples ({X_train_sel.shape[0]}) or classes ({len(unique_classes_train)}) in training set for RandomizedSearchCV. Skipping classification.")
        _record_skipped_fold()
        continue

    test_has_both_classes = np.unique(y_test).size == 2 and y_test.shape[0] >= 2

    for name, (model, param_grid) in models_and_params.items():
        print(f"  Model: {name}")
        try:
            if param_grid:
                # Randomized search over at most 20 settings (fewer if the grid is smaller).
                clf = RandomizedSearchCV(model, param_grid,
                                         n_iter=int(min(20, _search_space_size(param_grid))),
                                         cv=3, scoring='accuracy', n_jobs=-1, verbose=0, random_state=42)
                clf.fit(X_train_sel, y_train)
                best_estimator = clf.best_estimator_
                best_params = clf.best_params_
            else:
                # Nothing to tune: fit a fresh copy so the shared template in
                # `models_and_params` is never mutated across folds.
                best_estimator = clone(model)
                best_estimator.fit(X_train_sel, y_train)
                best_params = 'N/A'

            y_pred = best_estimator.predict(X_test_sel)
            acc = accuracy_score(y_test, y_pred) * 100
            f1 = f1_score(y_test, y_pred)
            roc_auc, fpr, tpr = np.nan, None, None

            print(f"    Accuracy: {acc:.2f}%")
            print(f"    F1-score: {f1:.2f}")
            print(f"    Best Params: {best_params}")
            if test_has_both_classes:
                print(f"    Confusion Matrix:\n{confusion_matrix(y_test, y_pred)}")
                if hasattr(best_estimator, "predict_proba"):
                    y_prob = best_estimator.predict_proba(X_test_sel)[:, 1]
                    roc_auc = roc_auc_score(y_test, y_prob)
                    fpr, tpr, _ = roc_curve(y_test, y_prob)
                    print(f"    ROC AUC: {roc_auc:.2f}")
                else:
                    print("    Model does not support predict_proba, skipping ROC AUC.")
            else:
                print("    Cannot compute confusion matrix or ROC (not enough unique classes or samples in test set).")
        except Exception as e:
            # Record a single NaN entry so the metric lists stay aligned.
            print(f"    Error during classification for {name}: {e}. Skipping.")
            _record_fold(name)
            continue

        # Metrics are appended only after everything for this model succeeded,
        # so an exception can never leave duplicate entries in the lists.
        _record_fold(name, acc, f1, roc_auc)
        if fpr is not None:
            all_fprs[name].append(fpr)
            all_tprs[name].append(tpr)

# Keep the LOSO summaries of the combined feature set in `analysis_results`
# under a key that names all four connectivity types.
analysis_results['combined_coh_dpli_mi_pearson_r'] = dict(
    loso_accuracies_summary=loso_accuracies_summary,
    loso_f1_summary=loso_f1_summary,
    loso_roc_auc_summary=loso_roc_auc_summary,
)


def _non_nan(values):
    """Return the entries of `values` that are not NaN (skipped folds are stored as NaN)."""
    return [v for v in values if not np.isnan(v)]


print(f"\n==================================================")
print(f"📊 OVERALL LOSO ACCURACY, F1-SCORE, and ROC AUC SUMMARY (Combined COH+DPLI+MI+pearson_r)")
print(f"==================================================")
for name in loso_accuracies_summary:
    valid_accs = _non_nan(loso_accuracies_summary[name])
    valid_f1s = _non_nan(loso_f1_summary[name])
    valid_aucs = _non_nan(loso_roc_auc_summary[name])

    # Guard clause: a model with no finite accuracy gets a single line.
    if not valid_accs:
        print(f"{name}: No valid metrics to report.")
        continue

    print(f"{name}:")
    print(f"  Accuracy: {np.mean(valid_accs):.2f}% ± {np.std(valid_accs):.2f}%")
    print(f"  F1-score: {np.mean(valid_f1s):.2f} ± {np.std(valid_f1s):.2f}" if valid_f1s
          else f"  F1-score: No valid F1-scores to report.")
    print(f"  ROC AUC: {np.mean(valid_aucs):.2f} ± {np.std(valid_aucs):.2f}" if valid_aucs
          else f"  ROC AUC: No valid ROC AUCs to report.")

print("\n--- Finished classification for combined connectivity types ---")





In [None]:
print(f"\n==================================================")
print(f"📈 AVERAGE ROC CURVES (Combined COH+DPLI+MI+pearson_r Features)")
print(f"==================================================")

# Mean ROC curve per model: each stored fold curve (from `all_fprs` / `all_tprs`)
# is interpolated onto a common FPR grid and the interpolated TPRs are averaged.
fig_roc, ax_roc = plt.subplots(figsize=(8, 5))
ax_roc.set_title('Average ROC Curve for Each Model (LOSO Cross-Validation)')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
chance_line, = ax_roc.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)

mean_fpr = np.linspace(0, 1, 100)

# Legend entries are collected by hand because the legend is drawn in its own figure.
handles = [chance_line]
labels = ['Chance']

for name in models_and_params:
    fold_fprs, fold_tprs = all_fprs[name], all_tprs[name]
    if not fold_fprs:
        print(f"No ROC data collected for {name}. Skipping plotting for this model.")
        continue

    # Each fold is interpolated from its own (fpr, tpr) pair. Per-fold AUCs are
    # deliberately not looked up by position in `loso_roc_auc_summary`: that list
    # also holds NaN entries for skipped folds, and the subject-dependent cell
    # re-creates all_fprs/all_tprs with a different fold count, so positions
    # need not match.
    tprs_interp = []
    for fold_fpr, fold_tpr in zip(fold_fprs, fold_tprs):
        interp_tpr = np.interp(mean_fpr, fold_fpr, fold_tpr)
        interp_tpr[0] = 0.0  # start every curve at (0, 0)
        tprs_interp.append(interp_tpr)

    if not tprs_interp:
        print(f"No valid interpolated TPRs to plot for {name} after processing folds.")
        continue

    mean_tpr = np.mean(tprs_interp, axis=0)
    mean_tpr[-1] = 1.0  # end the mean curve at (1, 1)
    mean_auc = auc(mean_fpr, mean_tpr)

    label = r'Mean %s ROC (AUC = %0.2f)' % (name, mean_auc)
    line, = ax_roc.plot(mean_fpr, mean_tpr, label=label, lw=2, alpha=.8)
    handles.append(line)
    labels.append(label)

    # ±1 std band of the interpolated TPRs, clipped to [0, 1]; not added to the legend.
    std_tpr = np.std(tprs_interp, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax_roc.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2)

ax_roc.grid(True)
plt.show()  # Display the main ROC plot

# Separate figure holding only the legend, sized by the number of models.
fig_legend = plt.figure(figsize=(6, len(models_and_params) * 0.75 + 1))
ax_legend = fig_legend.add_subplot(111)
ax_legend.legend(handles, labels, loc='center', frameon=False, prop={'size': 10})
ax_legend.axis('off')  # Hide the axes
plt.show()  # Display the legend plot


import numpy as np
import matplotlib.pyplot as plt

# Radar ("spider") chart of the mean LOSO accuracy of each classifier.
# `loso_accuracies_summary` maps model name -> per-subject accuracies in percent
# (NaN for skipped folds); it is produced by the LOSO classification cell.

# Models left out of the chart.
EXCLUDED_RADAR_MODELS = {"Bernoulli Naive Bayes"}
# Default lower limit of the radial axis (lowered automatically if needed).
RADAR_R_MIN = 50


def _mean_accuracy(values):
    """Mean of the non-NaN entries of `values`, or 0 when there are none."""
    finite = [v for v in values if not np.isnan(v)]
    return float(np.mean(finite)) if finite else 0.0


model_names = [model for model in loso_accuracies_summary if model not in EXCLUDED_RADAR_MODELS]
average_accuracies = np.array([_mean_accuracy(loso_accuracies_summary[model]) for model in model_names])
num_vars = len(model_names)

if num_vars == 0:
    print("No model accuracies available for the radar plot.")
else:
    # One spoke per model.
    angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()

    # Repeat the first point to close the polygon.
    average_accuracies = np.concatenate((average_accuracies, average_accuracies[:1]))
    angles += angles[:1]

    # Colormap for unique colors per model
    colormap = plt.colormaps['tab10'] if num_vars <= 10 else plt.colormaps['hsv']
    colors = [colormap(i / num_vars) for i in range(num_vars)]

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True))

    # Main polygon through the per-model averages
    ax.plot(angles, average_accuracies, color='black', linewidth=2, label='Accuracy Curve')
    ax.fill(angles, average_accuracies, color='lightgray', alpha=0.2)

    # Each model's point and value label in its own color
    for i in range(num_vars):
        ax.plot(angles[i], average_accuracies[i], 'o', color=colors[i], markersize=10)
        ax.text(angles[i], average_accuracies[i] + 5, f'{average_accuracies[i]:.1f}%',
                horizontalalignment='center', verticalalignment='bottom', fontsize=12, color=colors[i])

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(model_names, fontsize=13, fontweight='bold')
    ax.set_title('Average LOSO Accuracy per Classifier', size=18, y=1.1, fontweight='bold')

    # Radial labels
    ax.set_rlabel_position(180 / num_vars)
    ax.set_yticks([70, 75, 80, 85, 90, 95, 100])
    ax.set_yticklabels(["70%", "75%", "80%", "85%", "90%", "95%", "100%"], fontsize=12)
    # Lower the radial floor (in steps of 10, never below 0) when a model scores
    # below RADAR_R_MIN, so every point stays inside the plotted radial range.
    r_min = min(RADAR_R_MIN, max(0.0, 10 * np.floor(average_accuracies.min() / 10)))
    ax.set_ylim(r_min, 100)

    plt.tight_layout()
    # Save before plt.show(): after show() some backends close the figure and
    # a later save can produce an empty image.
    fig.savefig("model_accuracy_radar_colored_points.png", dpi=600)
    plt.show()

print("\n--- Plotting finished ---")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier,
                              AdaBoostClassifier, ExtraTreesClassifier)
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB, MultinomialNB, BernoulliNB, ComplementNB
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier

from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score, roc_curve, auc
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from scipy.stats import ttest_ind
import warnings
import pandas as pd
from collections import defaultdict

# NOTE(review): this silences *all* warnings for the rest of the kernel
# session, which would include sklearn convergence warnings and notices about
# unused/deprecated estimator parameters. Consider narrowing it, e.g.
# `warnings.filterwarnings('ignore', category=...)`.
warnings.filterwarnings('ignore')

def ttest_feature_selection(X_train, y_train, X_test, alpha=0.05, fallback_k=10):
    """
    Select features with a Welch t-test between the two classes.

    Column j is kept when the p-value of ``ttest_ind(..., equal_var=False)``
    comparing ``X_train[y_train == 1, j]`` against ``X_train[y_train == 0, j]``
    is below ``alpha``. The columns are chosen on the training data only and the
    same columns are then taken from ``X_test``.

    Fallbacks:
      * fewer than 2 samples in either class -> all features are returned;
      * no feature passes ``alpha`` -> the ``fallback_k`` features with the
        smallest p-values are returned (``np.argsort`` ranks NaN p-values, e.g.
        from columns constant in both classes, last);
      * if that still selects nothing (``fallback_k <= 0``) -> all features.

    Parameters
    ----------
    X_train : ndarray, shape (n_train, n_features)
    y_train : ndarray, shape (n_train,)
        Binary labels; the classes compared are ``1`` and ``0``
        (presumably 1 = attention, 0 = inattention — TODO confirm).
    X_test : ndarray, shape (n_test, n_features)
    alpha : float, default 0.05
        p-value threshold.
    fallback_k : int, default 10
        Number of lowest-p features kept when none passes ``alpha``.

    Returns
    -------
    X_train_sel, X_test_sel, selected_indices
        Column subsets of the inputs and the integer indices of the kept
        columns (relative to the input columns). When no column is kept, the
        arrays have zero columns and ``selected_indices`` is an empty integer
        array, so it can still be used for indexing.
    """
    def _empty_selection():
        # Zero-column outputs with an integer (indexable) empty index array.
        return (np.empty((X_train.shape[0], 0)),
                np.empty((X_test.shape[0], 0)),
                np.array([], dtype=int))

    n_features = X_train.shape[1]
    X_att = X_train[y_train == 1]
    X_inatt = X_train[y_train == 0]

    # A t-test needs at least two samples per class.
    if X_att.shape[0] < 2 or X_inatt.shape[0] < 2:
        print("Warning: Not enough samples in one or both classes for t-test. Skipping t-test based selection, returning all features.")
        if n_features > 0:
            return X_train, X_test, np.arange(n_features)
        return _empty_selection()

    if n_features == 0:
        print("Warning: No features in input to t-test selection. Returning empty arrays.")
        return _empty_selection()

    _, p_vals = ttest_ind(X_att, X_inatt, axis=0, equal_var=False)
    # NaN p-values compare False, so untestable columns never pass here.
    selected_indices = np.flatnonzero(p_vals < alpha)

    if selected_indices.size == 0:
        print(f"Warning: No features passed t-test threshold (alpha={alpha}). Selecting top {fallback_k} features by p-value.")
        selected_indices = np.argsort(p_vals)[:max(0, min(fallback_k, p_vals.size))]
        if selected_indices.size == 0:
            # Final fallback: keep every feature rather than none.
            selected_indices = np.arange(n_features)

    return X_train[:, selected_indices], X_test[:, selected_indices], selected_indices

# --- Models and hyperparameter grids (Optimized for speed) ---
# Each entry maps a display name to (unfitted estimator, search space for
# RandomizedSearchCV). The display names are the keys of the per-model result
# dicts and the labels in the plots.
models_and_params = {
    # NOTE(review): 'max_iter': [1000] in the grid overrides the constructor's
    # max_iter=2000 for every sampled setting, and saga may not converge in 1000
    # iterations (ConvergenceWarnings are hidden by the global warnings filter).
    # l1_ratio only takes effect when penalty='elasticnet'.
    'LogisticRegression': (LogisticRegression(max_iter=2000, random_state=42), {
        'C': [0.001, 0.01, 0.1, 1],  # Reduced range
        'penalty': ['l1', 'l2', 'elasticnet'],
        'solver': ['saga'],  # saga supports l1, l2 and elasticnet
        'l1_ratio': [0.1, 0.5, 0.9],  # For elasticnet
        'max_iter': [1000]  # Single value, good enough if convergence is met
    }),
    # 'hinge' loss has no predict_proba, so no ROC AUC is computed for it.
    "SGD Classifier": (SGDClassifier(random_state=42), {
        'loss': ['hinge', 'log_loss', 'modified_huber'],
        'penalty': ['l1', 'l2', 'elasticnet'],
        'alpha': [0.0001, 0.001, 0.01],  # Reduced range
        'l1_ratio': [0.3, 0.7],
        'learning_rate': ['constant', 'adaptive'],
        'eta0': [0.01, 0.1]  # must be > 0 for 'constant' / 'adaptive'
    }),
    'MLP': (MLPClassifier(max_iter=2000, random_state=42), {
        'hidden_layer_sizes': [
            (100,), (150,),  # Single layer
            (100, 50)  # Two layers
        ],
        'activation': ['relu', 'tanh'],
        'solver': ['adam'],  # Adam is generally faster
        'alpha': [0.001, 0.01],
        'learning_rate': ['constant', 'adaptive'],
        'learning_rate_init': [0.0001, 0.001],
    }),
    'RandomForest': (RandomForestClassifier(random_state=42), {
        'n_estimators': [100, 200],  # Fewer options
        'max_depth': [None, 10, 20],
        'min_samples_split': [2, 10],
        'min_samples_leaf': [1, 4],
        'max_features': ['sqrt', None],  # Simplified options
        'bootstrap': [True],
        'criterion': ['gini']
    }),
    "Extra Trees": (ExtraTreesClassifier(random_state=42), {
        'n_estimators': [100, 200],
        'max_depth': [None, 20],
        'min_samples_split': [2, 10],
        'min_samples_leaf': [1, 4],
        'max_features': ['sqrt', None],
        'criterion': ['gini']
    }),
    "Gradient Boosting": (GradientBoostingClassifier(random_state=42), {
        'n_estimators': [100],  # Fewer options
        'learning_rate': [0.05, 0.1],
        'max_depth': [3, 5],
        'subsample': [0.8, 1.0],
        'min_samples_split': [2],
        'min_samples_leaf': [1],
        'max_features': ['sqrt']
    }),
    # `use_label_encoder` is not passed: it was deprecated in XGBoost 1.6 and
    # removed in 2.0 (where it is only reported as an unused parameter). The
    # labels are presumably already encoded as 0/1.
    "XGBoost": (XGBClassifier(eval_metric='logloss', random_state=42, tree_method='hist'), {
        'n_estimators': [100, 200],
        'max_depth': [3, 6],
        'learning_rate': [0.05, 0.1],
        'subsample': [0.8, 1.0],
        'colsample_bytree': [0.8, 1.0],
        'gamma': [0, 0.2],
        'reg_alpha': [0, 0.1],
        'reg_lambda': [0.1, 1]
    }),
    # probability=True enables predict_proba, which the ROC AUC code relies on.
    'SVM': (SVC(probability=True, random_state=42), {
        'C': [0.1, 1],  # Reduced range
        'kernel': ['rbf'],  # Focus on most common kernel
        'gamma': ['scale', 0.1],
        'shrinking': [True],
        'class_weight': [None]  # Reduced option
    }),
    "Gaussian Naive Bayes": (GaussianNB(), {
        'var_smoothing': [1e-9, 1e-7]  # Reduced options
    }),
    # Features reaching the models are z-scored on the training fold, so
    # binarize=0.0 splits each feature at its training mean.
    "Bernoulli Naive Bayes": (BernoulliNB(), {
        'alpha': [0.1, 1.0],
        'binarize': [0.0],
        'fit_prior': [True, False]
    })
}

# ---------------------------------------------------------------------------
# Combine the COH, DPLI, MI and pearson_r feature matrices of every subject
# column-wise into one matrix per subject (same epochs, concatenated features).
# ---------------------------------------------------------------------------
print("\n--- Combining COH, DPLI, pearson_r and MI features and names for each subject ---")

CONNECTIVITY_MODALITIES = ['coh', 'dpli', 'mi', 'pearson_r']

combined_all_features_by_subject = {}
combined_all_labels_by_subject = {}

all_subject_ids_sets = [
    set(analysis_results[modality]['combined_features_by_subject'].keys())
    for modality in CONNECTIVITY_MODALITIES
    if modality in analysis_results and 'combined_features_by_subject' in analysis_results[modality]
]

if not all_subject_ids_sets:
    print("Error: No connectivity results found in 'analysis_results'. Please ensure previous steps ran successfully.")

# set().union(...) also works for an empty list (set.union(*[]) raises TypeError).
subject_ids = sorted(set().union(*all_subject_ids_sets))

# Feature names per modality. list() makes `+` below concatenate even if a
# name list was stored as an array.
coh_feature_names = list(analysis_results.get('coh', {}).get('all_feature_names', []))
dpli_feature_names = list(analysis_results.get('dpli', {}).get('all_feature_names', []))
mi_feature_names = list(analysis_results.get('mi', {}).get('all_feature_names', []))
pearson_r_feature_names = list(analysis_results.get('pearson_r', {}).get('all_feature_names', []))
feature_names_by_modality = {
    'coh': coh_feature_names,
    'dpli': dpli_feature_names,
    'mi': mi_feature_names,
    'pearson_r': pearson_r_feature_names,
}

# Column order of every combined matrix: COH | DPLI | MI | pearson_r.
combined_all_feature_names = coh_feature_names + dpli_feature_names + mi_feature_names + pearson_r_feature_names

# Largest epoch count over all subjects/modalities (informational only; rows
# are no longer padded to it).
max_epochs = max(
    (analysis_results[modality]['combined_features_by_subject'][sub_id].shape[0]
     for sub_id in subject_ids
     for modality in CONNECTIVITY_MODALITIES
     if modality in analysis_results and sub_id in analysis_results[modality].get('combined_features_by_subject', {})),
    default=0,
)
if max_epochs == 0:
    print("Warning: No epochs found across any subject or modality. Classification might fail.")


def _align_columns(features, n_expected, modality, subject_id):
    """Truncate or zero-pad the columns of `features` to `n_expected` (one per feature name)."""
    if features.shape[1] == n_expected:
        return features
    print(f"  Warning: Feature count mismatch for {modality} of {subject_id}. Expected {n_expected}, got {features.shape[1]}. Skipping or padding might occur.")
    if features.shape[1] > n_expected:
        return features[:, :n_expected]
    padded = np.zeros((features.shape[0], n_expected))
    padded[:, :features.shape[1]] = features
    return padded


for subject_id in subject_ids:
    print(f"Processing subject {subject_id} for combined features...")

    # Collect (features, labels) of every modality that has data for this subject.
    available = {}
    for modality in CONNECTIVITY_MODALITIES:
        modality_results = analysis_results.get(modality, {})
        features_for_modality = modality_results.get('combined_features_by_subject', {}).get(subject_id)
        labels_for_modality = modality_results.get('combined_labels_by_subject', {}).get(subject_id)

        if features_for_modality is None or features_for_modality.shape[0] == 0:
            print(f"  Warning: {modality.upper()} data missing or empty for subject {subject_id}. Appending zeros.")
            continue
        if labels_for_modality is None:
            print(f"  Warning: {modality.upper()} labels missing for subject {subject_id}. Appending zeros.")
            continue

        available[modality] = (
            _align_columns(np.asarray(features_for_modality),
                           len(feature_names_by_modality[modality]), modality, subject_id),
            np.asarray(labels_for_modality),
        )

    if not available:
        print(f"Skipping subject {subject_id}: No valid features to combine across modalities.")
        continue

    # Keep only epochs present in every available modality. Padding short
    # subjects with zero rows (labelled 0) would fabricate one-class samples.
    n_epochs = min(min(feats.shape[0], labs.shape[0]) for feats, labs in available.values())
    epoch_counts = {m: feats.shape[0] for m, (feats, _) in available.items()}
    if len(set(epoch_counts.values())) > 1:
        print(f"  Warning: Epoch counts differ between modalities for subject {subject_id} ({epoch_counts}). Truncating to {n_epochs} epochs.")
    if n_epochs == 0:
        print(f"Skipping subject {subject_id}: No epochs with both features and labels.")
        continue

    # Labels come from the first available modality; the others should agree.
    subject_labels = next(iter(available.values()))[1][:n_epochs]
    if any(not np.array_equal(labs[:n_epochs], subject_labels) for _, labs in available.values()):
        print(f"  Warning: Labels for subject {subject_id} differ between modalities. Using the first available labels.")

    subject_features_to_stack = []
    for modality in CONNECTIVITY_MODALITIES:
        if modality in available:
            subject_features_to_stack.append(available[modality][0][:n_epochs])
        else:
            # Missing modality: zero columns keep the layout aligned with
            # combined_all_feature_names.
            subject_features_to_stack.append(np.zeros((n_epochs, len(feature_names_by_modality[modality]))))

    if sum(f.shape[1] for f in subject_features_to_stack) == 0:
        print(f"Skipping subject {subject_id}: No valid features to combine across modalities.")
        continue

    combined_features = np.hstack(subject_features_to_stack)
    combined_all_features_by_subject[subject_id] = combined_features
    combined_all_labels_by_subject[subject_id] = subject_labels
    print(f"  Subject {subject_id}: Combined features shape {combined_features.shape}")

print(f"\n==================================================")
print(f"🔄 SUBJECT-DEPENDENT CLASSIFICATION (Combined COH+DPLI+MI+pearson_r Features)")
print(f"==================================================")

# Stratified cross-validation folds used within each subject.
n_folds = 2

# Per-subject metric containers; the fold loop adds one entry per subject.
subject_results = dict()

# Per-model lists of fold ROC points (one FPR / one TPR array per fold). These
# replace the LOSO curves collected earlier under the same names.
all_fprs, all_tprs = defaultdict(list), defaultdict(list)

# Per-subject (subject-dependent) cross-validation of every model in
# models_and_params. All preprocessing (imputation -> z-scoring -> variance
# threshold -> t-test selection) is fitted on the training fold only and then
# applied to the test fold, so no test-fold statistics leak into training.
from sklearn.base import clone
from sklearn.model_selection import ParameterGrid


def _n_search_iter(param_grid, cap=20):
    """Return how many candidates RandomizedSearchCV should sample.

    For a grid made only of lists/arrays this is the number of grid points,
    capped at ``cap`` (requesting more points than exist only makes sklearn
    warn and cap it itself). ParameterGrid rejects continuous distributions
    or malformed entries with TypeError/ValueError; then ``cap`` is used.
    """
    try:
        return min(cap, len(ParameterGrid(param_grid)))
    except (TypeError, ValueError):
        return cap


# Array form so selected column indices can be used for fancy indexing.
feature_names_arr = np.asarray(combined_all_feature_names)

for subject_id in subject_ids:
    print(f"\n--- Processing Subject {subject_id} ---")

    # A subject can be absent if an earlier feature-extraction step skipped it;
    # skip it here instead of aborting the whole run with a KeyError.
    if subject_id not in combined_all_features_by_subject or subject_id not in combined_all_labels_by_subject:
        print(f"  Skipping subject {subject_id}: No combined features/labels available.")
        continue

    # asarray so X[train_idx] / y[train_idx] also work if lists were stored.
    X = np.asarray(combined_all_features_by_subject[subject_id])
    y = np.asarray(combined_all_labels_by_subject[subject_id])

    # Skip if no data or only one class (nothing to classify).
    if X.shape[0] == 0 or len(np.unique(y)) < 2:
        print(f"  Skipping subject {subject_id}: Not enough data or only one class present.")
        continue

    # Per-model lists of per-fold scores for this subject.
    subject_results[subject_id] = {
        'accuracies': {name: [] for name in models_and_params},
        'f1_scores': {name: [] for name in models_and_params},
        'roc_aucs': {name: [] for name in models_and_params}
    }

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
        print(f"\n  Fold {fold_idx + 1}/{n_folds}")
        X_train_raw, X_test_raw = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        print(f"    Train samples: {X_train_raw.shape[0]}, Test samples: {X_test_raw.shape[0]}")
        print(f"    Features: {X_train_raw.shape[1]}")

        if X_train_raw.shape[0] == 0 or X_test_raw.shape[0] == 0:
            print("    Skipping fold: Empty train or test set.")
            continue

        if X_train_raw.shape[1] == 0:
            print("    Skipping fold: No features available.")
            continue

        # --- Preprocessing pipeline (each step fit on train, applied to both) ---
        imputer = SimpleImputer(strategy='mean')
        X_train_imputed = imputer.fit_transform(X_train_raw)
        X_test_imputed = imputer.transform(X_test_raw)

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train_imputed)
        X_test_scaled = scaler.transform(X_test_imputed)

        var_thresh = VarianceThreshold(threshold=1e-5)
        X_train_f = var_thresh.fit_transform(X_train_scaled)
        X_test_f = var_thresh.transform(X_test_scaled)
        # Column indices (into the original combined matrix) that survived.
        original_indices_after_var_thresh = var_thresh.get_support(indices=True)

        if X_train_f.shape[1] == 0:
            print("    No features remaining after variance threshold. Skipping models.")
            continue

        # t-test selection depends only on this fold's data, not on the model,
        # so run it once per fold rather than once per model.
        X_train_sel, X_test_sel, selected_relative_indices = ttest_feature_selection(X_train_f, y_train, X_test_f)

        # Map indices (relative to the variance-thresholded matrix) back to the
        # original combined feature names for reporting.
        if len(selected_relative_indices) > 0:
            selected_feature_names = feature_names_arr[original_indices_after_var_thresh[selected_relative_indices]]
        else:
            selected_feature_names = np.array([])

        print(f"    Selected {len(selected_feature_names)} features out of {X_train_raw.shape[1]} original combined features.")
        if len(selected_feature_names) > 0:
            print(f"    Selected Features: {', '.join(selected_feature_names[:10])}{'...' if len(selected_feature_names) > 10 else ''}")
        else:
            print("    No features selected.")

        if X_train_sel.shape[1] == 0:
            print("    No features remaining after t-test selection. Skipping models.")
            continue

        if X_train_sel.shape[0] < 3 or len(np.unique(y_train)) < 2:
            print("    Not enough samples or classes in training set. Skipping models.")
            continue

        # Confusion matrix / ROC AUC need both classes present in the test fold.
        test_has_both_classes = np.unique(y_test).size == 2 and y_test.shape[0] >= 2

        for name, (model, param_grid) in models_and_params.items():
            print(f"    Model: {name}")

            try:
                if param_grid:
                    # Randomized search over the model's grid; n_iter is tied to
                    # the grid size (not to the estimator's parameter count).
                    clf = RandomizedSearchCV(model, param_grid, n_iter=_n_search_iter(param_grid),
                                             cv=2, scoring='accuracy', n_jobs=-1, verbose=0, random_state=42)
                    clf.fit(X_train_sel, y_train)
                    best_estimator = clf.best_estimator_
                    best_params = clf.best_params_
                else:  # Models with no parameters to tune
                    # Fit a fresh copy so the shared estimator stored in
                    # models_and_params is not refitted/mutated on every fold.
                    best_estimator = clone(model)
                    best_estimator.fit(X_train_sel, y_train)
                    best_params = 'N/A'

                y_pred = best_estimator.predict(X_test_sel)
                acc = accuracy_score(y_test, y_pred) * 100
                f1 = f1_score(y_test, y_pred)

                subject_results[subject_id]['accuracies'][name].append(acc)
                subject_results[subject_id]['f1_scores'][name].append(f1)

                print(f"      Accuracy: {acc:.2f}%")
                print(f"      F1-score: {f1:.2f}")
                print(f"      Best Params: {best_params}")

                if test_has_both_classes:
                    print(f"      Confusion Matrix:\n{confusion_matrix(y_test, y_pred)}")

                    if hasattr(best_estimator, "predict_proba"):
                        # Probability of the second class (column 1) as the score.
                        y_prob = best_estimator.predict_proba(X_test_sel)[:, 1]
                        roc_auc = roc_auc_score(y_test, y_prob)
                        subject_results[subject_id]['roc_aucs'][name].append(roc_auc)
                        print(f"      ROC AUC: {roc_auc:.2f}")

                        # Keep the fold's curve for the average ROC plot.
                        fpr, tpr, _ = roc_curve(y_test, y_prob)
                        all_fprs[name].append(fpr)
                        all_tprs[name].append(tpr)
                    else:
                        subject_results[subject_id]['roc_aucs'][name].append(np.nan)
                        print("      Model does not support predict_proba, skipping ROC AUC.")
                else:
                    subject_results[subject_id]['roc_aucs'][name].append(np.nan)
                    print("      Cannot compute confusion matrix or ROC (not enough unique classes or samples in test set).")
            except Exception as e:
                # Best-effort: one failing model/fold must not stop the run.
                print(f"      Error during classification: {e}. Skipping.")

# Store the summary results for the combined features
analysis_results['combined_coh_dpli_mi_pearson_r_subject_dependent'] = {
    'subject_results': subject_results
}

# Summarise the cross-validation. Each subject's folds are averaged first;
# then the mean and (population) standard deviation of those per-subject means
# are reported per model. NaN entries (e.g. skipped ROC AUCs) are ignored.
print("\n==================================================")
print("📊 SUBJECT-DEPENDENT CLASSIFICATION SUMMARY")
print("==================================================")

# (results key, report template, message when no subject has a valid score)
summary_rows = (
    ('accuracies', "  Accuracy across subjects: {:.2f}% ± {:.2f}%", "  No valid accuracy scores to report."),
    ('f1_scores', "  F1-score across subjects: {:.2f} ± {:.2f}", "  No valid F1-scores to report."),
    ('roc_aucs', "  ROC AUC across subjects: {:.2f} ± {:.2f}", "  No valid ROC AUCs to report."),
)

for name in models_and_params:
    print(f"\n{name}:")

    for metric_key, template, empty_message in summary_rows:
        subject_means = []
        for subject_id in subject_ids:
            if subject_id not in subject_results:
                continue
            finite_scores = [s for s in subject_results[subject_id][metric_key][name] if not np.isnan(s)]
            if finite_scores:
                subject_means.append(np.mean(finite_scores))

        if subject_means:
            print(template.format(np.mean(subject_means), np.std(subject_means)))
        else:
            print(empty_message)

print("\n--- Finished subject-dependent classification ---")

# Plot one mean ROC curve per model. Every (subject, fold) curve collected in
# all_fprs/all_tprs is interpolated onto a shared FPR grid and averaged; the
# shaded band is +/-1 std of the interpolated TPRs, and the legend shows the
# AUC of the mean curve +/- the std of the per-fold AUCs.
print("\n==================================================")
print("📈 AVERAGE ROC CURVES (Subject-Dependent Classification)")
print("==================================================")

fig_roc, ax_roc = plt.subplots(figsize=(8, 5))
ax_roc.set_title('Average ROC Curve for Each Model (Subject-Dependent)')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
chance_line, = ax_roc.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)

mean_fpr = np.linspace(0, 1, 100)

# Legend entries are collected here and drawn in a separate figure below.
handles = [chance_line]
labels = ['Chance']

for name in models_and_params:
    # all_fprs/all_tprs are appended together per fold, so zip pairs them up.
    fold_curves = list(zip(all_fprs[name], all_tprs[name]))
    if not fold_curves:
        print(f"No ROC data collected for {name}. Skipping plotting for this model.")
        continue

    tprs_interp = []
    fold_aucs = []
    for fpr, tpr in fold_curves:
        tpr_on_grid = np.interp(mean_fpr, fpr, tpr)
        tpr_on_grid[0] = 0.0  # force the curve to start at (0, 0)
        tprs_interp.append(tpr_on_grid)
        fold_aucs.append(auc(fpr, tpr))  # per-fold AUC, for the spread in the label

    mean_tpr = np.mean(tprs_interp, axis=0)
    mean_tpr[-1] = 1.0  # force the curve to end at (1, 1)
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(fold_aucs)

    label = 'Mean %s ROC (AUC = %0.2f ± %0.2f)' % (name, mean_auc, std_auc)
    line, = ax_roc.plot(mean_fpr, mean_tpr, label=label, lw=2, alpha=.8)
    handles.append(line)
    labels.append(label)

    std_tpr = np.std(tprs_interp, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    # Shade in the model's own colour so overlapping bands stay distinguishable.
    ax_roc.fill_between(mean_fpr, tprs_lower, tprs_upper, color=line.get_color(), alpha=.2)

ax_roc.grid(True)
plt.show()  # Display the main ROC plot

# Separate figure holding only the legend; height scales with the model count.
fig_legend = plt.figure(figsize=(6, len(models_and_params) * 0.75 + 1))
ax_legend = fig_legend.add_subplot(111)
ax_legend.legend(handles, labels, loc='center', frameon=False, prop={'size': 10})
ax_legend.axis('off')
plt.show()  # Display the legend plot

print("\n--- Plotting finished ---")