In [1]:
import mne # MEG + EEG Analysis & Visualization

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import csd
import seaborn as sns
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="scipy.signal")

from data_preprocessing import preprocess # Function from data_preprocessing script
from CSD_matrix import average_csd

In [2]:
# Load the dataset through the project-local preprocess() helper.
# Per the original note, this step calculates the frequencies and power ranges.
# The cells below read the 'id' and 'schizo' columns of the result and pass
# the whole frame on to average_csd().
df1 = preprocess()

In [3]:
#Creates MNE Info object
n_channels = 16
n_times = 7680  #Number of data points
data = np.random.rand(n_channels, n_times)

sfreq = 128 #Sampling Frequency (Provided by the Dataset)
ch_names = [f"Ch{i+1}" for i in range(16)]
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * 16)

#Creates the RawArray
raw = mne.io.RawArray(data, info)

mapping = {
    "Ch1": "F7", "Ch2": "F3", "Ch3": "F4", "Ch4": "F8",
    "Ch5": "T3", "Ch6": "C3", "Ch7": "Cz", "Ch8": "C4",
    "Ch9": "T4", "Ch10": "T5", "Ch11": "P3", "Ch12": "Pz",
    "Ch13": "P4", "Ch14": "T6", "Ch15": "O1", "Ch16": "O2"}

raw.rename_channels(mapping)
montage = mne.channels.make_standard_montage('standard_1020')
raw.set_montage(montage)

%matplotlib qt
raw.plot_sensors(show_names=True, kind="topomap", sphere=(0, 0, 0, 0.09));


Creating RawArray with float64 data, n_channels=16, n_times=7680
    Range : 0 ... 7679 =      0.000 ...    59.992 secs
Ready.


In [4]:
# Split the distinct patient ids into the two cohorts using the 'schizo' flag
# (0 -> healthy list, 1 -> schizophrenic list). Rows with any other flag value
# end up in neither list.
healthy_patients = df1.loc[df1['schizo'] == 0, 'id'].unique().tolist()
schizo_patients = df1.loc[df1['schizo'] == 1, 'id'].unique().tolist()

def plot_correlation_matrices(healthy_patients, schizo_patients, wave_option='power_beta', ax=None,
                              diff_vmin=-0.05, diff_vmax=0.175):
    """Draw three heatmaps comparing the group-averaged matrices of two cohorts.

    The matrices are whatever ``average_csd`` returns for each list of patient
    ids; this function only plots them. It reads the module-level ``df1``
    frame and passes it to ``average_csd``.

    Panels, left to right:
      0. average matrix of ``healthy_patients``
      1. average matrix of ``schizo_patients``
      2. element-wise difference (healthy - schizophrenic)

    Panels 0 and 1 share one colour scale (the min/max over both matrices) so
    they can be compared visually; panel 2 uses ``diff_vmin``/``diff_vmax``.

    Parameters
    ----------
    healthy_patients, schizo_patients : list
        Patient ids forwarded to ``average_csd``.
    wave_option : str
        Band selector forwarded to ``average_csd`` (e.g. ``'power_beta'``);
        also shown in the panel titles.
    ax : sequence of three matplotlib Axes, optional
        Axes to draw on. When omitted, a new 1x3 figure is created. (The
        previous behaviour with the default ``None`` was a ``TypeError`` on
        ``ax[0]``.)
    diff_vmin, diff_vmax : float
        Colour limits of the difference panel. The defaults are the values
        that used to be hard-coded; differences outside this range are drawn
        with the end colours of the scale, so widen them for bands whose
        differences are larger.

    Returns
    -------
    The axes that were drawn on (``ax`` itself, or the newly created ones).
    """
    if ax is None:
        # No axes supplied: make a standalone row of three panels.
        _, ax = plt.subplots(1, 3, figsize=(18, 6))

    # Group averages for both cohorts.
    avg_corr_healthy = average_csd(healthy_patients, df1, wave_option=wave_option)
    avg_corr_schizo = average_csd(schizo_patients, df1, wave_option=wave_option)

    # Common colour limits for the two cohort panels.
    # .min().min() / .max().max() reduce column-wise first and then overall.
    global_vmin = min(avg_corr_healthy.min().min(), avg_corr_schizo.min().min())
    global_vmax = max(avg_corr_healthy.max().max(), avg_corr_schizo.max().max())

    corr_difference = avg_corr_healthy - avg_corr_schizo

    # Settings shared by all three heatmaps. The old fmt='.2f' argument is
    # omitted because it only affects annotations, which are switched off.
    heatmap_style = dict(annot=False, cmap='Oranges', xticklabels=True, yticklabels=True)

    # NOTE(review): the difference can be negative, yet it is drawn with the
    # same sequential 'Oranges' map as the cohort panels -- confirm this is
    # intended before switching to a diverging map.
    sns.heatmap(avg_corr_healthy, vmin=global_vmin, vmax=global_vmax, ax=ax[0], **heatmap_style)
    ax[0].set_title(f"Healthy Patients ({wave_option.capitalize()})")
    sns.heatmap(avg_corr_schizo, vmin=global_vmin, vmax=global_vmax, ax=ax[1], **heatmap_style)
    ax[1].set_title(f"Schizophrenic Patients ({wave_option.capitalize()})")
    sns.heatmap(corr_difference, vmin=diff_vmin, vmax=diff_vmax, ax=ax[2], **heatmap_style)
    ax[2].set_title("Difference (Healthy - Schizophrenic)")

    for axis in ax:
        axis.set_xlabel('Regions (Channels)')
        axis.set_ylabel('Regions (Channels)')

    return ax

"""Subplots"""
fig, axes = plt.subplots(3, 3, figsize=(18, 18))

# One row of three panels per frequency band, top to bottom: beta, alpha, theta.
for band_axes, band in zip(axes, ('power_beta', 'power_alpha', 'power_theta')):
    plot_correlation_matrices(healthy_patients, schizo_patients, wave_option=band, ax=band_axes)

plt.tight_layout()
plt.show()