### Overview - Graphing brain connectivity in schizophrenia from EEG data - Create DTF Graphs and Gold Table

EEG analysis was carried out on three versions of the data:
1. the raw EEG data,
2. data re-referenced with the Average Reference Method, and
3. data re-referenced with the Zero Reference Method.

This allowed us to explore how the choice of reference electrode impacts connectivity outcomes.
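
As a rough illustration of what re-referencing does, the sketch below applies an average reference to a small synthetic array. The function name `average_reference` and the `(n_channels, n_samples)` layout are assumptions made for this example, not code from the project's pipeline.

```python
import numpy as np


def average_reference(data):
    """Subtract the mean across channels from every channel.

    data: array of shape (n_channels, n_samples) (assumed layout).
    """
    return data - data.mean(axis=0, keepdims=True)


# Synthetic stand-in for a short 4-channel recording
rng = np.random.default_rng(0)
raw = rng.normal(size=(4, 6))

avg_ref = average_reference(raw)

# After average referencing, the channels sum to zero at every time point
column_sums = avg_ref.sum(axis=0)
```

Because every channel is shifted by the same amount at each time point, differences between channels are unchanged, but measures computed on individual channel waveforms can change.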

Connectivity was estimated with three methods, Phase-Locking Value (PLV), Phase-Lag Index (PLI), and Directed Transfer Function (DTF), and the EEG data were further analyzed with statistical indices based on graph theory.
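
For orientation, the two phase-based measures can be sketched for a single pair of signals as below. This is a minimal example on synthetic sinusoids, assuming instantaneous phase is taken from the Hilbert transform; the helper names `phase_difference`, `plv` and `pli` are hypothetical.

```python
import numpy as np
from scipy.signal import hilbert


def phase_difference(x, y):
    """Instantaneous phase difference of two 1-D signals via the Hilbert transform."""
    return np.angle(hilbert(x)) - np.angle(hilbert(y))


def plv(x, y):
    """Phase-Locking Value: length of the mean unit phasor of the phase difference."""
    return np.abs(np.mean(np.exp(1j * phase_difference(x, y))))


def pli(x, y):
    """Phase-Lag Index: absolute mean sign of the sine of the phase difference."""
    return np.abs(np.mean(np.sign(np.sin(phase_difference(x, y)))))


# Two seconds of a 10 Hz sinusoid sampled at 250 Hz (synthetic data)
t = np.arange(500) / 250.0
a = np.sin(2 * np.pi * 10 * t)
b = np.sin(2 * np.pi * 10 * t - np.pi / 4)  # same rhythm, constant phase lag

plv_lagged = plv(a, b)
pli_lagged = pli(a, b)
pli_zero_lag = pli(a, a)
```

A constant non-zero lag gives high values for both measures, while a zero-lag relationship drives PLI to zero. Neither measure is directional, which is one reason to also consider DTF.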

##### In this notebook we will:
  * Run a graph analysis of EEG connectivity. The project as a whole uses three connectivity measures:
    * Directed Transfer Function (DTF)
    * Phase-Locking Value (PLV)
    * Phase-Lag Index (PLI)
##### This notebook uses the Directed Transfer Function (DTF)

###### Convert the REST Reference Method PySpark DataFrame to a Pandas DataFrame so it can be used with the scipy, mne and numpy packages

In [0]:
# Load the data from the Butterworth Filtered Data REST Tables
band_names = ["delta", "theta", "alpha", "beta", "gamma"]

df_bands_rest = {}
# Create Pandas DataFrames
for band in band_names:
    df_bands_rest[band] = spark.sql(f"SELECT * FROM main.solution_accelerator.butter_rest_{band}_gold ORDER BY time ASC").toPandas()
    # Get distinct values from the 'patient_id' column
    patient_ids = df_bands_rest[band]['patient_id'].unique()
    print(f"patient_ids:::{patient_ids}")
# display(df_bands_rest[band])

###### Convert the Average Reference Method PySpark DataFrame to a Pandas DataFrame so it can be used with the scipy, mne and numpy packages

In [0]:
# Load the Butterworth-filtered Average Reference tables and convert them to Pandas DataFrames
band_names = ["delta", "theta", "alpha", "beta", "gamma"]

df_bands_avg = {}

for band in band_names:
    df_bands_avg[band] = spark.sql(f"SELECT * FROM main.solution_accelerator.butter_avg_{band}_gold ORDER BY time ASC").toPandas()
    # Get distinct values from the 'patient_id' column
    patient_ids = df_bands_avg[band]['patient_id'].unique()
    print(f"patient_ids:::{patient_ids}")
    display(df_bands_avg[band])


##### Graph analysis of EEG connectivity using the Directed Transfer Function (DTF)

##### Directed Transfer Function (DTF)
Directed Transfer Function (DTF) is a frequency-domain measure derived from multivariate autoregressive (MVAR) modeling of EEG signals. At each frequency, it estimates the directed influence of one channel on another, which makes it a measure of directed connectivity between brain regions.
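
For reference, the standard MVAR formulation of the normalized DTF can be written as follows. The symbols used here ($\mathbf{A}_k$, $p$, $f_s$, $N$) are notation chosen for this note and do not appear elsewhere in the notebook.

An MVAR model of order $p$ predicts the vector of $N$ channel values at sample $t$ from the previous $p$ samples:

$$\mathbf{x}(t) = \sum_{k=1}^{p} \mathbf{A}_k \, \mathbf{x}(t-k) + \mathbf{e}(t)$$

Transforming the model to the frequency domain, with sampling rate $f_s$, gives the transfer matrix:

$$\mathbf{H}(f) = \left( \mathbf{I} - \sum_{k=1}^{p} \mathbf{A}_k \, e^{-2\pi i f k / f_s} \right)^{-1}$$

The normalized DTF from channel $j$ to channel $i$ is then:

$$\gamma_{ij}^{2}(f) = \frac{\left| H_{ij}(f) \right|^{2}}{\sum_{m=1}^{N} \left| H_{im}(f) \right|^{2}}$$

With this normalization, each row of $\gamma^{2}(f)$ sums to 1, so $\gamma_{ij}^{2}(f)$ can be read as the fraction of the inflow to channel $i$ that originates from channel $j$ at frequency $f$.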

In [0]:
# Install MNE-Python, used below to build Raw objects and apply the electrode montage
%pip install mne

In [0]:
import mne

band_names = ["delta", "theta", "alpha", "beta", "gamma"]

# Sampling rate in Hz
sfreq = 250

mne_raw_all = {}

# Build one MNE Raw object per band and patient
for band in band_names:
    mne_raw_all[band] = {}
    
    # Get channel names by excluding the non-electrode columns
    ch_names = [c for c in df_bands_rest[band].columns if c not in ['patient_id', 'time']]
    # print(f"ch_names:::{ch_names}")

    # Get the distinct patient IDs
    pt_names = list(df_bands_rest[band]['patient_id'].unique())
    # print(f"patient_names:::{pt_names}")

    for pt in pt_names:
        print("PATIENT_ID::", pt)
        df_pt_data = df_bands_rest[band].loc[df_bands_rest[band]['patient_id'] == pt]
        df_pt_data = df_pt_data.drop(columns=['patient_id', 'time'])    
        # print("LEN::", len(df_pt_data.index))
        
        # Convert Pandas Dataframe to Numpy Array for each patient
        np_pt_data = df_pt_data.to_numpy() 

        # Create an info structure needed by MNE
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
        
        # Create the MNE Raw object
        mne_raw_pt = mne.io.RawArray(np_pt_data.T, info)
        
        # The MNE Raw object also stores the time axis; access it with `data, times = raw[:]`
        # Apply the standard 10-20 montage to map channel names to scalp positions
        mne_raw_all[band][pt] = mne_raw_pt.set_montage('standard_1020')
        
        # # Plot the data so we can compare graphs to reference methods later
        # print(f"Patient ID: {pt}")
        # mne_raw_all[band][pt] .plot(scalings=dict(eeg=50), title=(f"Patient ID: {pt}"), start=150, duration=100)
        # print(f"Patient ID: {pt}")
        # mne_raw_all[band][pt].plot_sensors(ch_type="eeg", title=(f"Patient ID: {pt}"))
        # print(f"Patient ID: {pt}")
        # spectrum = mne_raw_all[band][pt].compute_psd().plot(average=True, picks="data", exclude="bads", amplitude=False)
        
# Now we have our MNE Raw objects and are ready for further analysis


In [0]:
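import numpy as np
import pandas as pd

# NOTE: DTF is computed here from an MVAR fit, following the definition given earlier in
# this notebook. The model order and the frequency grid are assumptions and should be
# tuned for the data.


def fit_mvar(signals, order):
    """
    Fit a multivariate autoregressive (MVAR) model by ordinary least squares.

    Parameters:
    signals (numpy array): Shape (n_samples, n_channels), one column per channel.
    order (int): Number of past samples used to predict the current one.

    Returns:
    numpy array: Shape (order, n_channels, n_channels); coefs[k - 1][i, j] weights
    channel j at lag k when predicting channel i.
    """
    x = signals - signals.mean(axis=0)
    n_samples, n_channels = x.shape
    # Row r of the design matrix holds [x[t-1], x[t-2], ..., x[t-order]] for t = order + r
    lagged = np.hstack([x[order - k:n_samples - k] for k in range(1, order + 1)])
    target = x[order:]
    beta, *_ = np.linalg.lstsq(lagged, target, rcond=None)
    # beta has shape (order * n_channels, n_channels); reorder to (lag, to, from)
    return beta.reshape(order, n_channels, n_channels).transpose(0, 2, 1)


def dtf(coefs, n_freqs=64):
    """
    Compute the normalized DTF from MVAR coefficients, averaged over frequency.

    Frequencies are normalized (cycles per sample) and span 0 to the Nyquist limit (0.5).
    """
    order, n_channels, _ = coefs.shape
    freqs = np.linspace(0.0, 0.5, n_freqs)
    lags = np.arange(1, order + 1)
    dtf_per_freq = np.zeros((n_freqs, n_channels, n_channels))
    for idx, f in enumerate(freqs):
        # A(f) = I - sum_k A_k * exp(-2*pi*i*f*k); the transfer matrix is its inverse
        phase = np.exp(-2j * np.pi * f * lags)
        A_f = np.eye(n_channels) - np.tensordot(phase, coefs, axes=(0, 0))
        H_f = np.linalg.inv(A_f)
        power = np.abs(H_f) ** 2
        # Normalize each row so the squared inflows to a channel sum to 1
        dtf_per_freq[idx] = np.sqrt(power / power.sum(axis=1, keepdims=True))
    return dtf_per_freq.mean(axis=0)


def compute_dtf_matrix(pd_df, order=5, n_freqs=64):
    """
    Compute the Directed Transfer Function (DTF) adjacency matrix for multiple EEG channels.

    Parameters:
    pd_df (pandas DataFrame): DataFrame containing EEG signals as columns.
    order (int): MVAR model order (assumed default).
    n_freqs (int): Number of frequencies to average over (assumed default).

    Returns:
    numpy array: Adjacency matrix where entry (i, j) represents the DTF from channel j to channel i.
    """
    # Prepare a 2D array where each column is a signal
    signals = pd_df.to_numpy(dtype=float)
    coefs = fit_mvar(signals, order)
    return dtf(coefs, n_freqs)


band_names = ["delta", "theta", "alpha", "beta", "gamma"]
# Columns that are not electrode channels
columns_to_drop = ['patient_id', 'time']

# DTF adjacency matrices, keyed by band and then by patient ID
dtf_matrices = {}

for band in band_names:
    dtf_matrices[band] = {}
    patients = df_bands_rest[band]['patient_id'].unique().tolist()
    print(patients)
    for pt in patients:
        # Keep only the electrode channels and the first 1000 samples for this patient
        df_pt = (
            df_bands_rest[band]
            .loc[df_bands_rest[band]['patient_id'] == pt]
            .drop(columns=columns_to_drop)
            .head(1000)
        )
        channel_names = list(df_pt.columns)
        print(f"channel_names:::{channel_names}")

        # Compute DTF for all pairs of channels
        dtf_matrices[band][pt] = compute_dtf_matrix(df_pt)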