<h1>Bayesian PAC (Group Comparison)</h1>

The objective of this code is to classify the subjects using the connections that were found to be significant or near-significant. A non-parametric permutation test is then applied to assess whether the classification results could have arisen by chance.

<h2>Initialization</h2>

<h3>Libraries</h3>

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import seaborn as sns

from matplotlib.patches import Arc, Circle, ConnectionPatch
from scipy.stats import norm
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from xgboost import XGBClassifier


<h3>Global Vars</h3>

In [2]:
# Electrode montage. 32 channels exist, but Cz serves as the reference for
# offset and normalization purposes, so 31 channels remain for the analysis.
G_nChannels = 31
G_EEG_labels = [
    'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
    'FC5', 'FC1', 'FC2', 'FC6',
    'T7', 'C3', 'C4', 'T8',
    'TP9', 'CP5', 'CP1', 'CP2', 'CP6', 'TP10',
    'P7', 'P3', 'PZ', 'P4', 'P8',
    'PO9', 'O1', 'OZ', 'O2', 'PO10',
]

# Reference p-values from which the z-score thresholds are derived.
G_Alpha = 0.1        # significant connections
G_Alpha_near = 0.25  # near-significant connections

# Resampling budgets.
G_nSurrogates = 200     # surrogates evaluated to build the null distribution
G_nPermutations = 1000  # iterations of the classification permutation test

<h3>(Optional) Verification of the Bonferroni-corrected alpha selection</h3>

In [None]:
# One comparison per ordered (channel, channel) pair.
n_comparisons = G_nChannels ** 2

# Bonferroni: divide each reference alpha by the number of comparisons, then
# take the matching quantile of the standard normal distribution.
alpha_bonferroni = G_Alpha / n_comparisons
alpha_bonferroni_near = G_Alpha_near / n_comparisons
z_alpha_bonferroni = norm.ppf(1 - alpha_bonferroni)
z_alpha_bonferroni_near = norm.ppf(1 - alpha_bonferroni_near)

# Report for the significance level
print(f"\nFor G_Alpha = {G_Alpha} with {n_comparisons} comparisons:")
print(f"Bonferroni-corrected alpha = {alpha_bonferroni:.4e}")
print(f"Bonferroni-corrected z-score: {z_alpha_bonferroni:.4e}")

# Report for the near-significance level
print(f"\nFor G_Alpha_near = {G_Alpha_near} with {n_comparisons} comparisons:")
print(f"Bonferroni-corrected alpha_near = {alpha_bonferroni_near:.4e}")
print(f"Bonferroni-corrected z-score_near: {z_alpha_bonferroni_near:.4e}")

<h3>Auxiliary Functions</h3>

<h4>Classification</h4>

In [4]:
# Define a function for training and evaluating XGBoost
def train_evaluate_xgboost(X, y):
    """
    Fit an XGBoost classifier on a hold-out split and score it on the held-out part.

    The data are split 70/30 with a fixed seed (random_state=42), so repeated
    calls with inputs of the same length use the same partition of the rows.

    Parameters:
    X (numpy.ndarray): Feature matrix of shape (n_samples, n_features).
    y (numpy.ndarray): Response vector of shape (n_samples,).

    Returns:
    tuple: Seven scores computed on the test set, in this order:
        - accuracy (float)
        - balanced_accuracy (float)
        - auc (float): ROC AUC computed from the predicted probability of class 1.
        - precision (float)
        - recall (float): Sensitivity (recall of class 1).
        - specificity (float): Recall of class 0.
        - f1 (float)
      Precision, recall, specificity and F1 use zero_division=1.
    """
    # NOTE(review): the split is not stratified, so the class balance of the
    # test set depends on the order of y — confirm this is intended.
    features_fit, features_eval, labels_fit, labels_eval = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    classifier = XGBClassifier(eval_metric='logloss')
    classifier.fit(features_fit, labels_fit)

    # Hard labels feed the threshold metrics; class-1 probabilities feed the AUC.
    predicted_labels = classifier.predict(features_eval)
    class1_probability = classifier.predict_proba(features_eval)[:, 1]

    return (
        accuracy_score(labels_eval, predicted_labels),
        balanced_accuracy_score(labels_eval, predicted_labels),
        roc_auc_score(labels_eval, class1_probability),
        precision_score(labels_eval, predicted_labels, zero_division=1),
        recall_score(labels_eval, predicted_labels, zero_division=1),
        recall_score(labels_eval, predicted_labels, pos_label=0, zero_division=1),
        f1_score(labels_eval, predicted_labels, zero_division=1),
    )

<h4>Visualization</h4>

In [5]:
def self_connection_plot(ax, pos, node_radius, z_value, color='black', linewidth=6, center=(0,0)):
    """
    Draw a self-connection as an arc with an arrowhead and a z-value label.

    The arc lies on a "ghost" circle with the same radius as the node, shifted
    from the node towards the center of the connectogram by 1.2 * node_radius.
    The two circles intersect in two points; of the two arcs of the ghost
    circle they delimit, the one facing the connectogram center is drawn.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis where to draw.
    pos : tuple
        (x, y) coordinates of the node.
    node_radius : float
        Radius of the node.
    z_value : float
        Z value written next to the arc (formatted with one decimal).
    color : str
        Color of the arc, the arrowhead and the label text.
    linewidth : float
        Width of the arc.
    center : tuple
        Center of the connectogram.

    Returns
    -------
    None
        Patches and text are added to ``ax``. Nothing is drawn when the node
        coincides with ``center``, when ``node_radius`` is zero, or when the
        two circles do not intersect.
    """
    node_xy = np.asarray(pos, dtype=float)
    center_xy = np.asarray(center, dtype=float)
    cx, cy = center_xy

    # Unit vector from the node towards the center of the connectogram.
    to_center = center_xy - node_xy
    to_center_len = np.linalg.norm(to_center)
    if to_center_len == 0:
        # The node sits on the center: there is no inward direction, and
        # normalising would divide by zero and propagate NaNs into the plot.
        return
    v = to_center / to_center_len

    # Center of the shifted (ghost) circle
    ghost_center = node_xy + v * node_radius * 1.2

    # Original and shifted circles
    x0, y0 = node_xy
    x1, y1 = ghost_center
    r = node_radius

    # Distance between the two circle centers
    d = np.linalg.norm(ghost_center - node_xy)
    if d == 0 or d > 2*r:
        # Coincident circles (zero radius) or no intersection: abort
        return

    # Intersection points of two equal-radius circles: they lie on the
    # perpendicular bisector of the segment joining the centers.
    a = d/2
    h = np.sqrt(max(r**2 - a**2, 0.0))  # clamp guards against round-off at d == 2r
    mid = (node_xy + ghost_center) / 2
    perp = np.array([-(y1-y0), x1-x0]) / d

    inter1 = mid + h * perp
    inter2 = mid - h * perp

    # Angles of intersection points relative to the center of the shifted circle
    angle1 = np.degrees(np.arctan2(inter1[1]-y1, inter1[0]-x1))
    angle2 = np.degrees(np.arctan2(inter2[1]-y1, inter2[0]-x1))

    # Choose the arc that faces the center of the connectogram: compare the
    # mid-angle of the counter-clockwise arc angle1 -> angle2 with the
    # direction from the ghost circle to the center.
    dir_to_center = np.degrees(np.arctan2(cy-y1, cx-x1))
    angle_diff = (angle2 - angle1) % 360
    arc_mid = angle1 + angle_diff/2
    # Wrap the deviation into [-180, 180) so that e.g. 350 deg and -10 deg are
    # recognised as the same direction before applying the 90 deg test.
    deviation = (arc_mid - dir_to_center + 180.0) % 360.0 - 180.0
    if abs(deviation) > 90:
        # Reverse direction if the arc doesn't face the center
        angle1, angle2 = angle2, angle1

    # Draw the arc
    arc = Arc(ghost_center, 2*r, 2*r, angle=0, theta1=angle1, theta2=angle2,
              color=color, lw=linewidth, zorder=10)
    ax.add_patch(arc)

    # Draw the arrow at the end of the arc: end point on the ghost circle at
    # angle2 and the counter-clockwise tangent at that point.
    theta = np.radians(angle2)
    end = np.array([x1 + r * np.cos(theta), y1 + r * np.sin(theta)])
    tangent = np.array([-np.sin(theta), np.cos(theta)])
    # Short shaft behind the arrow head, along the tangent
    arrow_start = end - tangent * node_radius * 0.3
    ax.annotate(
        '', xy=end, xytext=arrow_start,
        arrowprops=dict(arrowstyle='-|>', color=color, lw=linewidth, shrinkA=0, shrinkB=0),
        zorder=11
    )

    # Annotate the z value just beyond the ghost circle, on the inward side
    label_pos = ghost_center + v * r * 1.2
    ax.text(label_pos[0], label_pos[1],
            f'{z_value:.1f}',
            ha='center', va='center',
            fontsize=16, fontweight='bold',
            bbox=dict(facecolor='white', edgecolor='none', alpha=0.8, pad=2),
            color=color,
            zorder=12)

def create_advanced_connectogram(matrix, labels, save_str='default', significant_threshold=3.7, near_threshold=3.4, title=None):
    """
    Create an advanced connectogram visualization using matplotlib.

    Nodes (electrodes) are placed evenly on a circle. Every off-diagonal entry
    of ``matrix`` whose absolute value reaches a threshold is drawn as an arrow
    between the two nodes, annotated with the signed value. Diagonal entries
    are drawn as self-connections with self_connection_plot().

    Parameters
    ----------
    matrix : numpy.ndarray
        Matrix of Bonferroni values (destinations × sources).
        The matrix should be square with dimensions matching the length of labels.
        The entry read for the edge between labels[i] and labels[j] is matrix[j, i].
    labels : list
        List of electrode labels (e.g., ['F3', 'FC5', etc.])
    save_str : str, optional
        String inserted in the output file names, by default 'default'
    significant_threshold : float, optional
        |z| at or above this value is drawn as significant, by default 3.7
    near_threshold : float, optional
        |z| at or above this value (and below significant_threshold) is drawn
        as near-significant, by default 3.4
    title : str, optional
        Title for the plot ('Controls' or 'Reading Difficulties'), by default None

    Returns
    -------
    None
        The function saves the plot as 'Figures/connectogram_<save_str>.png'
        and '.eps' (creating the folder if needed) and displays it.

    Notes
    -----
    - Significant connections are shown in black with linewidth 6
    - Near-significant connections are shown in gray with linewidth 3
    - The legend shows the thresholds passed in, but its "(p ≤ 0.1)" and
      "(p ≤ 0.2)" annotations are fixed text and do not follow the thresholds
    - Connections are curved according to how far apart the nodes are on the
      circle; far connections (node distance > 5) are drawn as straight lines
    - Value labels are placed next to the midpoint of the two nodes
    - Self-connections are drawn by self_connection_plot() as arcs on the
      inward side of the node (towards the center of the connectogram)
    - NOTE(review): edges use arrowstyle '<|-' on a patch running from
      labels[i] (xyA) to labels[j] (xyB), which should put the arrow head at
      the labels[i] end — confirm this matches the intended
      source/destination orientation of ``matrix``.
    """
    import os  # local import: only used to make sure the output folder exists

    fig, ax = plt.subplots(figsize=(18, 18))

    if title:
        ax.set_title(title, pad=20, fontsize=32, fontweight='bold')

    n_nodes = len(labels)
    radius = 10
    node_radius = 0.7
    angles = np.linspace(0, 2*np.pi, n_nodes, endpoint=False)

    # Evenly spaced node centers on the circle, keyed by label.
    node_positions = {
        label: (radius * np.cos(angle), radius * np.sin(angle))
        for label, angle in zip(labels, angles)
    }

    def nodes_between(src_idx, dst_idx):
        """Return the shortest distance, in node steps, around the circle."""
        dist = abs(dst_idx - src_idx)
        return min(dist, n_nodes - dist)

    def get_normal_vector(dx, dy):
        """Return the unit normal of (dx, dy); (0, 1) for a zero-length vector."""
        length = np.sqrt(dx*dx + dy*dy)
        if length < 1e-10:
            return np.array([0, 1])
        return np.array([-dy/length, dx/length])

    def get_connection_style(src_pos, dst_pos, node_dist):
        """Pick the arc curvature from the node distance and chord position."""
        mid_x = (src_pos[0] + dst_pos[0]) / 2
        mid_y = (src_pos[1] + dst_pos[1]) / 2
        dist_from_center = np.sqrt(mid_x**2 + mid_y**2)
        if node_dist > 5:
            return "arc3,rad=0"
        if dist_from_center > radius * 0.9:
            # Chord hugs the rim: bend it strongly
            if node_dist <= 1:
                return "arc3,rad=-0.8"
            elif node_dist <= 3:
                return "arc3,rad=-0.6"
            else:
                return "arc3,rad=-0.4"
        else:
            if node_dist <= 1:
                return "arc3,rad=0.5"
            elif node_dist <= 3:
                return "arc3,rad=0.3"
            else:
                return "arc3,rad=0.2"

    def draw_edges(selects, line_color, line_width, line_zorder, text_color=None):
        """Draw every off-diagonal connection whose |z| satisfies ``selects``."""
        for i, src in enumerate(labels):
            for j, dst in enumerate(labels):
                if i == j:
                    continue  # Self-connections are drawn separately
                z_signed = matrix[j, i]
                if not selects(abs(z_signed)):
                    continue
                src_pos = node_positions[src]
                dst_pos = node_positions[dst]
                node_dist = nodes_between(i, j)
                con = ConnectionPatch(
                    xyA=src_pos, xyB=dst_pos,
                    coordsA="data", coordsB="data",
                    axesA=ax, axesB=ax,
                    color=line_color,
                    linewidth=line_width,
                    connectionstyle=get_connection_style(src_pos, dst_pos, node_dist),
                    arrowstyle='<|-',
                    mutation_scale=20,
                    shrinkA=node_radius*40,
                    shrinkB=node_radius*40,
                    zorder=line_zorder
                )
                ax.add_patch(con)
                # Value label, nudged off the chord midpoint along its normal
                mid_x = (src_pos[0] + dst_pos[0]) / 2
                mid_y = (src_pos[1] + dst_pos[1]) / 2
                normal = get_normal_vector(dst_pos[0] - src_pos[0],
                                           dst_pos[1] - src_pos[1])
                label_pos = np.array([mid_x, mid_y]) + normal * 0.3
                text_kwargs = dict(
                    horizontalalignment='center',
                    verticalalignment='center',
                    fontsize=24,
                    fontweight='bold',
                    bbox=dict(facecolor='white',
                              edgecolor='none',
                              alpha=0.8,
                              pad=2),
                    zorder=4,
                )
                if text_color is not None:
                    text_kwargs['color'] = text_color
                ax.text(label_pos[0], label_pos[1], f'{z_signed:.1f}', **text_kwargs)

    # Near-significant edges first (gray, below), then significant ones (black, above)
    draw_edges(lambda z: near_threshold <= z < significant_threshold,
               line_color='#CCCCCC', line_width=3, line_zorder=1, text_color='#666666')
    draw_edges(lambda z: z >= significant_threshold,
               line_color='black', line_width=6, line_zorder=2)

    # Draw self-connections
    for i, label in enumerate(labels):
        zval = abs(matrix[i, i])
        if zval >= significant_threshold:
            arc_color, arc_width = 'black', 6
        elif zval >= near_threshold:
            arc_color, arc_width = '#CCCCCC', 3
        else:
            continue
        self_connection_plot(
            ax=ax,
            pos=node_positions[label],
            node_radius=node_radius,
            z_value=matrix[i, i],
            color=arc_color,
            linewidth=arc_width,
            center=(0,0)
        )

    # Draw nodes: three faint offset discs as a drop shadow, then the node itself
    for label, pos in node_positions.items():
        for offset in [0.15, 0.1, 0.05]:
            shadow = Circle(
                (pos[0]-offset/2, pos[1]-offset/2),
                radius=node_radius,
                facecolor='gray',
                alpha=0.1,
                zorder=5
            )
            ax.add_patch(shadow)
        circle = Circle(
            pos,
            radius=node_radius,
            facecolor='white',
            edgecolor='red',
            linewidth=2,
            zorder=6
        )
        ax.add_patch(circle)
        ax.text(pos[0], pos[1],
                label,
                horizontalalignment='center',
                verticalalignment='center',
                fontsize=24,
                fontweight='bold',
                zorder=7)

    legend_elements = [
        plt.Line2D([0], [0], color='black', linewidth=6,
                  label=f'|z| ≥ {significant_threshold:.1f} (p ≤ 0.1)'),
        plt.Line2D([0], [0], color='#CCCCCC', linewidth=3,
                  label=f'|z| ≥ {near_threshold:.1f} (p ≤ 0.2)')
    ]
    legend = ax.legend(handles=legend_elements,
                      loc='upper right',
                      title='Bonferroni Value',
                      fontsize=24,
                      title_fontsize=28,
                      frameon=True,
                      facecolor='white',
                      edgecolor='none')
    legend.get_frame().set_alpha(1.0)
    ax.set_aspect('equal')
    ax.axis('off')
    limit = radius + 2
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)

    png_path = f'Figures/connectogram_{save_str}.png'
    eps_path = f'Figures/connectogram_{save_str}.eps'
    # savefig does not create missing folders; make sure the target exists.
    os.makedirs(os.path.dirname(png_path), exist_ok=True)
    plt.savefig(png_path,
                dpi=300,
                bbox_inches='tight',
                facecolor='white',
                edgecolor='none')
    plt.savefig(eps_path,
                format='eps',
                dpi=300,
                bbox_inches='tight',
                facecolor='white',
                edgecolor='none')
    plt.show()

<h2>Reading Data</h2>

Once the probability matrices of all subjects have been obtained, each one is converted into a row vector and the rows are concatenated into a results matrix for the class to which the subjects belong. In our case we compare Control subjects against subjects with Reading Difficulties for three auditory stimuli (4.8 Hz, 16 Hz and 40 Hz):

In [6]:
# One [controls, reading-difficulties] pair of result files per stimulus frequency
filename_save = [
    ['results_4_8Hz_Controls.pkl', 'results_4_8Hz_Difficulties.pkl'],
    ['results_16Hz_Controls.pkl', 'results_16Hz_Difficulties.pkl'],
    ['results_40Hz_Controls.pkl', 'results_40Hz_Difficulties.pkl'],
]

<h2>Main Pipeline</h2>

In [None]:
def _permutation_p_value(null_scores, observed_score):
    """
    Return the one-sided permutation p-value of ``observed_score``.

    Uses (count + 1) / (n + 1), where count is the number of null scores that
    are greater than or equal to the observed one. Counting the observed
    labelling as one of the permutations keeps the estimate from being
    exactly 0, which a finite permutation test cannot support.
    """
    null_scores = np.asarray(null_scores)
    return (np.sum(null_scores >= observed_score) + 1) / (null_scores.size + 1)


# z-score threshold for near-significant connections (Bonferroni over all
# channel pairs). It is computed here so the pipeline does not depend on the
# optional verification cell having been executed first.
z_alpha_bonferroni_near = norm.ppf(1 - G_Alpha_near / (G_nChannels * G_nChannels))

for freq_idx, (controls_file, reading_difficulties_file) in enumerate(filename_save):

    #### Controls ####
    ##################

    # Load Data
    # NOTE: pickle can execute arbitrary code while loading; only open result
    # files that come from a trusted source.
    with open(controls_file, 'rb') as f:
        controls_probability_matrices = pickle.load(f)

    # Convert the list of matrices to a NumPy array
    # (assumes one flattened G_nChannels*G_nChannels row per subject — TODO confirm)
    controls_probability_matrices_array = np.array(controls_probability_matrices)

    # Compute the real means for each column (source-destination connections)
    controls_real_means = np.mean(controls_probability_matrices_array, axis=0)

    # Initialize array to store the means from the permutations
    controls_surrogate_means = np.zeros((G_nSurrogates, controls_probability_matrices_array.shape[1]))

    # Generate the surrogates: shuffle every entry of the group matrix
    # (across subjects and connections) and recompute the column means
    for s in range(G_nSurrogates):
        controls_permuted_matrix = np.random.permutation(controls_probability_matrices_array.flatten()).reshape(controls_probability_matrices_array.shape)
        controls_surrogate_means[s, :] = np.mean(controls_permuted_matrix, axis=0)

    # Calculate the mean and standard deviation from the surrogate data (null distribution)
    controls_surrogate_mean = np.mean(controls_surrogate_means, axis=0)
    controls_surrogate_std = np.std(controls_surrogate_means, axis=0)

    # Compute z-scores for the real connections (left at 0 where the null std is 0)
    controls_z_scores = np.zeros(controls_real_means.shape)
    controls_non_zero_std = controls_surrogate_std > 0  # Avoid division by zero
    controls_z_scores[controls_non_zero_std] = (controls_real_means[controls_non_zero_std] - controls_surrogate_mean[controls_non_zero_std]) / controls_surrogate_std[controls_non_zero_std]

    # Second Bonferroni correction to determine the significance threshold
    controls_z_threshold_bonferroni = norm.ppf(1 - G_Alpha / controls_probability_matrices_array.shape[1])

    # Identify significant connections using Bonferroni correction
    controls_significant_connections_bonferroni = np.where(np.abs(controls_z_scores) > controls_z_threshold_bonferroni)

    # Display significant connections based on Bonferroni correction
    print(f"[CONTROLS] Bonferroni corrected z-threshold: {controls_z_threshold_bonferroni}")
    print("[CONTROLS] Significant connections (source -> destination, z-score):")
    for idx in controls_significant_connections_bonferroni[0]:
        i, j = divmod(idx, G_nChannels)
        print(f"\tConnection from {G_EEG_labels[i]} to {G_EEG_labels[j]} with z-score: {controls_z_scores[idx]:.3f}")

    # Convert z-scores to square matrices of size G_nChannels x G_nChannels
    controls_z_scores_matrix = controls_z_scores.reshape(G_nChannels, G_nChannels)

    # Identify connections above near significant threshold
    controls_significant_connections_fixed = np.where(np.abs(controls_z_scores_matrix) > z_alpha_bonferroni_near)
    controls_significant_connections_posterior_analysis = set(zip(controls_significant_connections_fixed[0], controls_significant_connections_fixed[1]))
    print(f"[CONTROLS] Evaluating connections (source -> destination) with a z-score above near significant threshold ({z_alpha_bonferroni_near})")
    for i, j in controls_significant_connections_posterior_analysis:
        print(f"\tConnection from {G_EEG_labels[i]} to {G_EEG_labels[j]} with z-score: {controls_z_scores_matrix[i, j]:.3f}")

    # Connectogram showing results
    # NOTE(review): the figure uses the function's default thresholds
    # (3.7 / 3.4), not the thresholds computed above — confirm this is intended.
    create_advanced_connectogram(controls_z_scores_matrix, G_EEG_labels,
                                 save_str=f"{controls_file}_Updated",
                                 title="Controls")

    #### Reading Difficulties ####
    ##############################

    # Load Data (same trust caveat for pickle as above)
    with open(reading_difficulties_file, 'rb') as f:
        reading_difficulties_probability_matrices = pickle.load(f)

    # Convert the list of matrices to a NumPy array
    reading_difficulties_probability_matrices_array = np.array(reading_difficulties_probability_matrices)

    # Compute the real means for each column (source-destination connections)
    reading_difficulties_real_means = np.mean(reading_difficulties_probability_matrices_array, axis=0)

    # Initialize array to store the means from the permutations
    reading_difficulties_surrogate_means = np.zeros((G_nSurrogates, reading_difficulties_probability_matrices_array.shape[1]))

    # Generate the surrogates and calculate means
    for s in range(G_nSurrogates):
        reading_difficulties_permuted_matrix = np.random.permutation(reading_difficulties_probability_matrices_array.flatten()).reshape(reading_difficulties_probability_matrices_array.shape)
        reading_difficulties_surrogate_means[s, :] = np.mean(reading_difficulties_permuted_matrix, axis=0)

    # Calculate the mean and standard deviation from the surrogate data (null distribution)
    reading_difficulties_surrogate_mean = np.mean(reading_difficulties_surrogate_means, axis=0)
    reading_difficulties_surrogate_std = np.std(reading_difficulties_surrogate_means, axis=0)

    # Compute z-scores for the real connections (left at 0 where the null std is 0)
    reading_difficulties_z_scores = np.zeros(reading_difficulties_real_means.shape)
    reading_difficulties_non_zero_std = reading_difficulties_surrogate_std > 0  # Avoid division by zero
    reading_difficulties_z_scores[reading_difficulties_non_zero_std] = (reading_difficulties_real_means[reading_difficulties_non_zero_std] - reading_difficulties_surrogate_mean[reading_difficulties_non_zero_std]) / reading_difficulties_surrogate_std[reading_difficulties_non_zero_std]

    # Second Bonferroni correction to determine the significance threshold
    reading_difficulties_z_threshold_bonferroni = norm.ppf(1 - G_Alpha / reading_difficulties_probability_matrices_array.shape[1])

    # Identify significant connections using Bonferroni correction
    reading_difficulties_significant_connections_bonferroni = np.where(np.abs(reading_difficulties_z_scores) > reading_difficulties_z_threshold_bonferroni)

    # Display significant connections based on Bonferroni correction
    print(f"[READING DIFFICULTIES] Bonferroni corrected z-threshold: {reading_difficulties_z_threshold_bonferroni}")
    print("[READING DIFFICULTIES] Significant connections (source -> destination, z-score):")
    for idx in reading_difficulties_significant_connections_bonferroni[0]:
        i, j = divmod(idx, G_nChannels)
        print(f"\tConnection from {G_EEG_labels[i]} to {G_EEG_labels[j]} with z-score: {reading_difficulties_z_scores[idx]:.3f}")

    # Convert z-scores to square matrices of size G_nChannels x G_nChannels
    reading_difficulties_z_scores_matrix = reading_difficulties_z_scores.reshape(G_nChannels, G_nChannels)

    # Identify connections above near significant threshold
    reading_difficulties_significant_connections_fixed = np.where(np.abs(reading_difficulties_z_scores_matrix) > z_alpha_bonferroni_near)
    reading_difficulties_significant_connections_posterior_analysis = set(zip(reading_difficulties_significant_connections_fixed[0], reading_difficulties_significant_connections_fixed[1]))
    print(f"[READING DIFFICULTIES] Evaluating connections (source -> destination) with a z-score above near significant threshold ({z_alpha_bonferroni_near})")
    for i, j in reading_difficulties_significant_connections_posterior_analysis:
        print(f"\tConnection from {G_EEG_labels[i]} to {G_EEG_labels[j]} with z-score: {reading_difficulties_z_scores_matrix[i, j]:.3f}")

    # Connectogram showing results
    create_advanced_connectogram(reading_difficulties_z_scores_matrix, G_EEG_labels,
                                 save_str=f"{reading_difficulties_file}_Updated",
                                 title="Reading_Difficulties")

    #### Classification ####
    ########################

    # Combine significant connections from both groups for the current frequency
    # NOTE(review): these features are selected with the group labels of ALL
    # subjects, i.e. before the train/test split and before the label
    # permutations below — confirm this selection bias is acceptable.
    all_significant_connections_posterior_analysis = controls_significant_connections_posterior_analysis.union(reading_difficulties_significant_connections_posterior_analysis)

    # Convert the significant (i, j) pairs into linear indices for accessing flat arrays
    linear_indices = [i * G_nChannels + j for i, j in all_significant_connections_posterior_analysis]
    feature_columns = np.asarray(linear_indices, dtype=int)

    # Predictor matrix: the selected connections of every subject, controls
    # first; response vector: 0 = control, 1 = reading difficulties
    X = np.vstack([
        controls_probability_matrices_array[:, feature_columns],
        reading_difficulties_probability_matrices_array[:, feature_columns],
    ])
    y = np.concatenate([
        np.zeros(len(controls_probability_matrices_array), dtype=int),
        np.ones(len(reading_difficulties_probability_matrices_array), dtype=int),
    ])

    # Shuffle X and y together to randomize the order for robust training
    X, y = shuffle(X, y, random_state=42)

    # Print the shape of X and y to confirm correct structure
    print(f"Predictor matrix (X) shape for frequency {freq_idx + 1}: {X.shape}")
    print(f"Response vector (y) shape for frequency {freq_idx + 1}: {y.shape}")

    # Train and evaluate the model on the real data
    real_accuracy, real_balanced_accuracy, real_auc, real_precision, real_recall, real_specificity, real_f1 = train_evaluate_xgboost(X, y)
    print("\n=== Classification Results ===")
    print(f"\tAccuracy: {real_accuracy:.3f}")
    print(f"\tBalanced Accuracy: {real_balanced_accuracy:.3f}")
    print(f"\tAUC: {real_auc:.3f}")
    print(f"\tPrecision: {real_precision:.3f}")
    print(f"\tRecall (Sensitivity): {real_recall:.3f}")
    print(f"\tSpecificity: {real_specificity:.3f}")
    print(f"\tF1-Score: {real_f1:.3f}")

    #### Permutation Test ####
    ##########################

    # Null distribution: retrain and re-evaluate with randomly permuted labels.
    # Each row holds the seven metrics of one permutation, in the order
    # returned by train_evaluate_xgboost.
    permuted_metrics = np.array(
        [train_evaluate_xgboost(X, np.random.permutation(y)) for _ in range(G_nPermutations)]
    ).reshape(-1, 7)
    (accuracy_permuted, balanced_accuracy_permuted, auc_permuted, precision_permuted,
     recall_permuted, specificity_permuted, f1_permuted) = permuted_metrics.T

    # Calculate the p-values for each metric
    p_accuracy = _permutation_p_value(accuracy_permuted, real_accuracy)
    p_balanced_accuracy = _permutation_p_value(balanced_accuracy_permuted, real_balanced_accuracy)
    p_auc = _permutation_p_value(auc_permuted, real_auc)
    p_precision = _permutation_p_value(precision_permuted, real_precision)
    p_recall = _permutation_p_value(recall_permuted, real_recall)
    p_specificity = _permutation_p_value(specificity_permuted, real_specificity)
    p_f1 = _permutation_p_value(f1_permuted, real_f1)

    print("\n=== Permutation Test Results (p-values) ===")
    print(f"p-value for Accuracy: {p_accuracy:.3f}")
    print(f"p-value for Balanced Accuracy: {p_balanced_accuracy:.3f}")
    print(f"p-value for AUC: {p_auc:.3f}")
    print(f"p-value for Precision: {p_precision:.3f}")
    print(f"p-value for Recall (Sensitivity): {p_recall:.3f}")
    print(f"p-value for Specificity: {p_specificity:.3f}")
    print(f"p-value for F1-Score: {p_f1:.3f}")