In [13]:
import mne # MEG + EEG Analysis & Visualization

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import csd
import seaborn as sns
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="scipy.signal")

from data_preprocessing import preprocess # Function from data_preprocessing script

In [5]:
# Load the preprocessed table once via the project's data_preprocessing.preprocess().
# Later cells read its 'id', 'region' and 'schizo' columns plus the 'power_*' columns.
df1 = preprocess()

In [15]:
# Step 4: Create MNE Info object
n_channels = 16
n_times = 7680  # number of time points
data = np.random.rand(n_channels, n_times)

sfreq = 128  # Replace with your actual sampling frequency
ch_names = [f"Ch{i+1}" for i in range(16)]  # Create channel names for 16 channels
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * 16)



# Step 5: Create the RawArray
raw = mne.io.RawArray(data, info)

mapping = {
    "Ch1": "F7", "Ch2": "F3", "Ch3": "F4", "Ch4": "F8",
    "Ch5": "T3", "Ch6": "C3", "Ch7": "Cz", "Ch8": "C4",
    "Ch9": "T4", "Ch10": "T5", "Ch11": "P3", "Ch12": "Pz",
    "Ch13": "P4", "Ch14": "T6", "Ch15": "O1", "Ch16": "O2"}

raw.rename_channels(mapping)
montage = mne.channels.make_standard_montage('standard_1020')
raw.set_montage(montage)
%matplotlib qt
raw.plot_sensors(show_names=True, kind="topomap", sphere=(0, 0, 0, 0.09));


Creating RawArray with float64 data, n_channels=16, n_times=7680
    Range : 0 ... 7679 =      0.000 ...    59.992 secs
Ready.


In [16]:
def cross_spectral_matrix(patient_id, wave_option='power_beta', surrogate=False, data=None):
    """Build a normalised cross-spectral matrix between one patient's regions.

    For every pair of regions the cross spectral density of their
    ``wave_option`` series is estimated with ``scipy.signal.csd`` and averaged
    over frequency. Each off-diagonal entry is then divided by the square root
    of the product of the two regions' own (diagonal) averages, and the
    diagonal is forced to 1.

    Parameters
    ----------
    patient_id
        Value matched against the ``'id'`` column.
    wave_option : str
        Name of the column holding one sequence of samples per region row.
    surrogate : bool
        Accepted for backward compatibility; it currently has no effect.
    data : pandas.DataFrame, optional
        Table with ``'id'``, ``'region'`` and ``wave_option`` columns.
        Defaults to the notebook-level ``df1``.

    Returns
    -------
    csd_matrix : numpy.ndarray
        Complex ``(n_regions, n_regions)`` array. Rows and columns follow the
        order in which regions first appear in the patient's rows.
    freqs : numpy.ndarray
        Frequency bins returned by ``scipy.signal.csd``.

    Raises
    ------
    ValueError
        If no row matches ``patient_id`` (previously this surfaced as an
        ``UnboundLocalError`` at the return statement).
    """
    source = df1 if data is None else data

    # Filter data for the selected patient
    patient_data = source[source['id'] == patient_id]
    if patient_data.empty:
        raise ValueError(f"No rows found for patient {patient_id!r}")

    sampling_freq = 128

    # One series per region; only the first row of each region is used.
    power_data = [
        patient_data.loc[patient_data['region'] == region, wave_option].values[0]
        for region in patient_data['region'].unique()
    ]

    # Rows become samples, columns become regions (channels).
    power_df = pd.DataFrame(power_data).transpose()

    # Check for NaN values and fill them
    if power_df.isnull().values.any():
        print(f"Warning: Missing values detected in power data for patient {patient_id}")
        power_df = power_df.fillna(0)  # Fill NaN with 0 (you can change this to another method)

    signals = power_df.to_numpy(dtype=float).T  # channels x samples
    n_channels, n_samples = signals.shape

    # SciPy clamps nperseg to the signal length anyway (with a warning);
    # clamping here gives the same segmentation without the warning.
    nperseg = min(256, n_samples)

    csd_matrix = np.zeros((n_channels, n_channels), dtype=complex)
    freqs = None

    # Swapping the two inputs of csd() conjugates the result, so only the
    # upper triangle is computed and mirrored into the lower one.
    for i in range(n_channels):
        for j in range(i, n_channels):
            freqs, csd_values = csd(signals[i], signals[j], fs=sampling_freq, nperseg=nperseg)

            # Collapse the spectrum to a single frequency-averaged value.
            csd_matrix[i, j] = np.mean(csd_values)
            if j != i:
                csd_matrix[j, i] = np.conj(csd_matrix[i, j])

    # Normalise every entry by sqrt(auto_i * auto_j).
    auto_power = np.diag(csd_matrix).copy()
    scale = np.sqrt(np.outer(auto_power, auto_power))

    # A flat channel has zero auto-spectrum; leave its couplings at 0 rather
    # than producing 0/0 = NaN, which would contaminate any later averaging.
    normalised = np.zeros_like(csd_matrix)
    np.divide(csd_matrix, scale, out=normalised, where=(scale != 0))

    # The diagonal is set to exactly 1 so that taking the magnitude of an
    # across-patient average leaves it at 1.
    np.fill_diagonal(normalised, 1)

    return normalised, freqs
    

def average_csd(patient_ids, wave_option='power_beta'):
    """Average the normalised cross-spectral matrices of several patients.

    Each patient's complex matrix from ``cross_spectral_matrix`` is summed,
    the sum is divided by the number of patients, and the magnitude of the
    result is returned as a labelled table.

    Parameters
    ----------
    patient_ids : iterable
        Patient identifiers passed one by one to ``cross_spectral_matrix``.
    wave_option : str
        Column name forwarded to ``cross_spectral_matrix``.

    Returns
    -------
    pandas.DataFrame
        16 x 16 table of magnitudes, indexed and labelled by electrode name.

    Raises
    ------
    ValueError
        If ``patient_ids`` is empty (previously an opaque ``TypeError`` from
        dividing ``None`` by zero).
    """
    patient_ids = list(patient_ids)
    if not patient_ids:
        raise ValueError("patient_ids must contain at least one patient")

    csd_sums = None

    # Accumulate the complex matrices patient by patient.
    for patient_id in patient_ids:
        csd_matrix, _ = cross_spectral_matrix(patient_id, wave_option)

        if csd_sums is None:
            # Copy so the running total never aliases a per-patient result.
            csd_sums = csd_matrix.copy()
        else:
            csd_sums += csd_matrix

    # Average in the complex domain first, then keep only the magnitude
    # (the phase is discarded; np.angle would give the phase instead).
    average_csd_matrix = np.abs(csd_sums / len(patient_ids))

    # NOTE(review): these labels assume every patient's regions appear in
    # exactly this order in the source table -- confirm against the data.
    region_labels = ["F7", "F3", "F4", "F8", "T3", "C3", "Cz", "C4", "T4", "T5", "P3", "Pz", "P4", "T6", "O1", "O2"]
    average_csd_df = pd.DataFrame(average_csd_matrix, columns=region_labels, index=region_labels)

    return average_csd_df

# Unique patient ids per cohort, split on the 'schizo' flag (0 -> healthy list, 1 -> schizo list).
healthy_patients = df1.loc[df1['schizo'] == 0, 'id'].unique().tolist()
schizo_patients = df1.loc[df1['schizo'] == 1, 'id'].unique().tolist()

def plot_correlation_matrices(healthy_patients, schizo_patients, wave_option='power_beta', ax=None,
                              diff_vmin=0.06, diff_vmax=0.175):
    """Draw healthy, schizophrenic and difference heatmaps side by side.

    Parameters
    ----------
    healthy_patients, schizo_patients : list
        Patient ids of each cohort, forwarded to ``average_csd``.
    wave_option : str
        Column name forwarded to ``average_csd``; also shown in the titles.
    ax : sequence of three matplotlib Axes, optional
        Targets for the healthy, schizophrenic and difference panels. When
        omitted a new 1 x 3 figure is created (previously ``ax=None`` failed
        with a ``TypeError`` on the first ``ax[0]`` access).
    diff_vmin, diff_vmax : float
        Colour limits of the difference panel; values outside are clipped.

    Returns
    -------
    The three axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 3, figsize=(18, 6))

    # Average matrices for both cohorts
    avg_corr_healthy = average_csd(healthy_patients, wave_option=wave_option)
    avg_corr_schizo = average_csd(schizo_patients, wave_option=wave_option)

    # Shared colour limits so the two cohort panels are directly comparable
    global_vmin = min(avg_corr_healthy.min().min(), avg_corr_schizo.min().min())
    global_vmax = max(avg_corr_healthy.max().max(), avg_corr_schizo.max().max())

    corr_difference = avg_corr_healthy - avg_corr_schizo

    # Options common to all three heatmaps
    heatmap_kwargs = dict(annot=False, cmap='Oranges', fmt='.2f', xticklabels=True, yticklabels=True)

    sns.heatmap(avg_corr_healthy, vmin=global_vmin, vmax=global_vmax, ax=ax[0], **heatmap_kwargs)
    ax[0].set_title(f"Healthy Patients ({wave_option.capitalize()})")
    sns.heatmap(avg_corr_schizo, vmin=global_vmin, vmax=global_vmax, ax=ax[1], **heatmap_kwargs)
    ax[1].set_title(f"Schizophrenic Patients ({wave_option.capitalize()})")
    sns.heatmap(corr_difference, vmin=diff_vmin, vmax=diff_vmax, ax=ax[2], **heatmap_kwargs)
    ax[2].set_title("Difference (Healthy - Schizophrenic)")

    for axis in ax:
        axis.set_xlabel('Regions (Channels)')
        axis.set_ylabel('Regions (Channels)')

    return ax
# 3 x 3 grid: one row per frequency band, one column per panel
fig, axes = plt.subplots(3, 3, figsize=(18, 18))

# Rows from top to bottom: beta, alpha, theta
for row_axes, band in zip(axes, ('power_beta', 'power_alpha', 'power_theta')):
    plot_correlation_matrices(healthy_patients, schizo_patients, wave_option=band, ax=row_axes)

plt.tight_layout()
plt.show()