### Overview - Graphing brain connectivity in schizophrenia from EEG data - Create DTF Graphs and Gold Table

EEG analysis was carried out on three versions of the data:
1. the raw EEG data,
2. data re-referenced with the Average Reference Method, and
3. data re-referenced with the Zero Reference Method (REST).

This allowed us to explore how the choice of reference electrode impacts connectivity outcomes.
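
As a rough illustration of what average re-referencing does, the sketch below subtracts the mean across channels from every channel at each time point. The function name `average_reference` and the random test data are hypothetical and not part of this notebook's pipeline. The Zero Reference Method is not shown, since it additionally depends on a head model.

```python
import numpy as np

def average_reference(data):
    """Re-reference EEG data (channels x time points) to the common average."""
    data = np.asarray(data, dtype=float)
    # Subtract the instantaneous mean across channels from every channel
    return data - data.mean(axis=0, keepdims=True)

# Hypothetical data: 4 channels, 1000 samples
rng = np.random.default_rng(0)
raw = rng.standard_normal((4, 1000))
avg_ref = average_reference(raw)

# After re-referencing, the channels sum to zero at every time point
print(np.allclose(avg_ref.mean(axis=0), 0.0))
```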

EEG data were analyzed using three connectivity methods, Phase-Locking Value (PLV), Phase-Lag Index (PLI), and Directed Transfer Function (DTF), together with statistical indices based on graph theory.
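
For orientation, the two phase-based measures can be sketched as follows, assuming the usual definitions: PLV is the magnitude of the mean unit phasor of the phase difference, and PLI is the magnitude of the mean sign of its sine. The function name `plv_pli` and the synthetic 10 Hz signals are hypothetical. The phases come from `scipy.signal.hilbert`, which is meaningful only for band-limited signals.

```python
import numpy as np
from scipy.signal import hilbert

def plv_pli(x, y):
    """Return (PLV, PLI) for two equally long, band-limited 1-D signals."""
    # Instantaneous phase difference from the analytic signals
    dphi = np.angle(hilbert(x)) - np.angle(hilbert(y))
    plv = np.abs(np.mean(np.exp(1j * dphi)))       # consistency of the phase difference
    pli = np.abs(np.mean(np.sign(np.sin(dphi))))   # consistency of its sign (lead/lag)
    return plv, pli

sfreq = 250                                  # Hz
t = np.arange(0, 4, 1 / sfreq)
x = np.sin(2 * np.pi * 10 * t)               # 10 Hz oscillation
y = np.sin(2 * np.pi * 10 * t - np.pi / 4)   # same oscillation with a constant lag

plv, pli = plv_pli(x, y)
print(f"PLV={plv:.2f}, PLI={pli:.2f}")
```

A constant non-zero lag drives both measures towards 1, whereas a zero-lag coupling (for example `plv_pli(x, x)`) keeps PLV at 1 but gives a PLI of 0.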

##### In this notebook we will:
  * Run a graph analysis of EEG connectivity. The overall analysis uses three connectivity measures:
    * Directed Transfer Function (DTF)
    * Phase-Locking Value (PLV)
    * Phase-Lag Index (PLI)
##### This notebook uses the Directed Transfer Function (DTF)

###### We need to convert the REST Reference Method PySpark DataFrame to a Pandas DataFrame so we can use it with the scipy, mne and numpy packages

In [0]:
# Load the data from the Butterworth Filtered Data REST Tables
band_names = ["delta","theta", "alpha", "beta", "gamma"]

df_bands_rest = {}
# Create one Pandas DataFrame per frequency band
for band in band_names:
    df_bands_rest[band] = spark.sql(f"SELECT * FROM main.solution_accelerator.butter_rest_{band}_gold ORDER BY time ASC").toPandas()

# Preview the last band loaded (gamma)
display(df_bands_rest[band])

###### We need to convert the Average Reference Method PySpark DataFrame to a Pandas DataFrame so we can use it with the scipy, mne and numpy packages

In [0]:
# Create one Pandas DataFrame per frequency band for the two selected patients
df_bands_avg = {}
for band in band_names:
    df_bands_avg[band] = spark.sql(f"SELECT * FROM main.solution_accelerator.butter_avg_{band}_gold WHERE patient_id in ('s11', 'h13') ORDER BY time ASC").toPandas()

# Preview the last band loaded (gamma)
display(df_bands_avg[band])


##### Graph analysis of EEG connectivity using the Directed Transfer Function (DTF)

##### Directed Transfer Function (DTF)
The Directed Transfer Function (DTF) is a frequency-domain measure derived from multivariate autoregressive (MVAR) modeling of EEG signals. It estimates the directed influence that one brain region (channel) exerts on another at each frequency.
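
To make the MVAR-based definition concrete, here is a minimal sketch under stated assumptions; it is not the implementation used later in this notebook. It assumes the standard formulation: fit X(t) = Σ_k A_k X(t−k) + E(t), form A(f) = I − Σ_k A_k e^(−2πi f k / sfreq), take the transfer matrix H(f) = A(f)⁻¹, and normalize |H_ij(f)|² by the sum over row i. The helpers `fit_mvar` and `mvar_dtf`, the model order of 2, and the simulated two-channel signal are all hypothetical.

```python
import numpy as np

def fit_mvar(data, order):
    """Least-squares MVAR fit; data has shape (channels, time points)."""
    n_channels, n_times = data.shape
    # Each regression row stacks the `order` previous samples of all channels
    lagged = np.hstack([data[:, order - k:n_times - k].T for k in range(1, order + 1)])
    target = data[:, order:].T
    coefs, *_ = np.linalg.lstsq(lagged, target, rcond=None)
    # Result A[k-1][i, j]: weight of channel j at lag k when predicting channel i
    return coefs.T.reshape(n_channels, order, n_channels).transpose(1, 0, 2)

def mvar_dtf(A, freqs, sfreq):
    """Normalized DTF, shape (channels, channels, frequencies); [i, j] is j -> i."""
    order, n_channels, _ = A.shape
    lags = np.arange(1, order + 1)
    dtf = np.zeros((n_channels, n_channels, len(freqs)))
    for f_idx, f in enumerate(freqs):
        phase = np.exp(-2j * np.pi * f * lags / sfreq)
        A_f = np.eye(n_channels) - np.tensordot(phase, A, axes=1)
        H = np.linalg.inv(A_f)                      # transfer matrix H(f)
        power = np.abs(H) ** 2
        dtf[:, :, f_idx] = power / power.sum(axis=1, keepdims=True)
    return dtf

# Simulated example: channel 0 drives channel 1, but not the other way round
rng = np.random.default_rng(0)
n_times = 5000
noise = rng.standard_normal((2, n_times))
x = np.zeros((2, n_times))
for t in range(1, n_times):
    x[0, t] = 0.5 * x[0, t - 1] + noise[0, t]
    x[1, t] = 0.4 * x[1, t - 1] + 0.6 * x[0, t - 1] + noise[1, t]

A = fit_mvar(x, order=2)
dtf = mvar_dtf(A, freqs=np.arange(8, 12.5, 0.5), sfreq=250)
print("0 -> 1:", dtf[1, 0].mean(), " 1 -> 0:", dtf[0, 1].mean())
```

In this simulation the flow from channel 0 to channel 1 should come out clearly larger than the reverse direction, which is the asymmetry that distinguishes DTF from undirected measures such as PLV and PLI.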

In [0]:
# MNE-Python: library for EEG/MEG analysis
%pip install mne

In [0]:
import mne
import numpy as np
from scipy.signal import csd

def compute_dtf(data, sfreq, freq_band, n_channels):
    """
    Compute a Directed Transfer Function (DTF) style matrix for EEG data.

    Note: this estimate is built from the inverse of the cross-spectral
    density (CSD) matrix rather than from a fitted MVAR model.

    Parameters:
    data : np.ndarray
        Multichannel time series data (channels x time points).
    sfreq : float
        Sampling frequency of the data.
    freq_band : tuple
        Frequency band for DTF calculation (e.g., (8, 12) for alpha band).
    n_channels : int
        Number of channels.

    Returns:
    dtf : np.ndarray
        Directed Transfer Function matrix (channels x channels x frequencies).
        Frequencies outside `freq_band` are left at zero.
    """
    n_times = data.shape[1]
    # Use Welch segments shorter than the recording. A single segment
    # (nperseg=n_times) gives a rank-1 CSD matrix that cannot be inverted.
    nperseg = min(n_times, int(2 * sfreq))
    freqs = np.fft.rfftfreq(nperseg, 1 / sfreq)
    freq_mask = (freqs >= freq_band[0]) & (freqs <= freq_band[1])

    # Compute Cross-Spectral Density (CSD) matrix
    csd_matrix = np.zeros((n_channels, n_channels, len(freqs)), dtype=complex)
    for i in range(n_channels):
        for j in range(n_channels):
            _, Pxy = csd(data[i], data[j], fs=sfreq, nperseg=nperseg)
            csd_matrix[i, j, :] = Pxy

    # Compute DTF for the frequencies inside the band
    dtf = np.zeros((n_channels, n_channels, len(freqs)))
    for f_idx in np.flatnonzero(freq_mask):
        # The pseudo-inverse stays defined even if the CSD matrix is ill-conditioned
        H_inv = np.linalg.pinv(csd_matrix[:, :, f_idx])
        power = np.abs(H_inv) ** 2
        # Normalize each row so that it sums to 1 (guarding against empty rows)
        row_sums = power.sum(axis=1, keepdims=True)
        dtf[:, :, f_idx] = power / np.where(row_sums == 0, 1.0, row_sums)
    return dtf

# Define the sampling frequency (in Hz)
sampling_freq = 250 

# Define frequency bands
frequency_bands = {
    'delta': (2, 4),
    'theta': (4.5, 7.5),
    'alpha': (8, 12.5),
    'beta': (13, 30),
    'gamma': (30, 45)
}

times = []
channel_names = []

# Patients to analyze (the same two selected in the Average Reference query)
patient_ids = ['s11', 'h13']

# Transpose data for each band and patient and store in a dictionary
transposed_data = {}
for band in band_names:
    df_band = df_bands_rest[band]
    for patient_id in patient_ids:
        # Keep one patient's rows so that recordings are not mixed together
        df_patient = df_band[df_band['patient_id'] == patient_id]
        if df_patient.empty:
            continue

        # Get the times
        times = df_patient['time'].values.tolist()

        # Get channel names as a list. Drop columns that are not electrode channels
        columns_to_drop = ['patient_id', 'time']
        channel_names = df_patient.columns.drop(columns_to_drop).tolist()

        # Extract data for the current band and transpose to shape (n_channels, n_times)
        rest_data_T = df_patient[channel_names].values.T

        # Store transposed data in the dictionary, keyed by (band, patient_id)
        transposed_data[(band, patient_id)] = rest_data_T

        print(f"Band: {band}, patient: {patient_id}, channels: {len(channel_names)}")

        # The data is already band-pass filtered, so compute DTF in the matching frequency range
        freq_range = frequency_bands[band]
        dtf = compute_dtf(data=rest_data_T, sfreq=sampling_freq, freq_band=freq_range, n_channels=len(channel_names))
        print(dtf[-1])
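
Because the goal is a graph analysis, a DTF array of shape (channels x channels x frequencies) eventually has to become a graph. The sketch below shows one possible way, not necessarily the one this notebook uses: average over the frequency bins of a band, remove self-connections, and binarize with a threshold. The function name `dtf_to_adjacency`, the threshold of 0.5, and the random demo array are hypothetical. Entry `[i, j]` is read as influence from channel `j` to channel `i`, matching the row-wise normalization above.

```python
import numpy as np

def dtf_to_adjacency(dtf, freq_mask, threshold):
    """Average DTF over the selected frequency bins and binarize it."""
    band_dtf = dtf[:, :, freq_mask].mean(axis=2)
    np.fill_diagonal(band_dtf, 0.0)              # ignore self-connections
    return (band_dtf >= threshold).astype(int)

# Hypothetical DTF array: 3 channels, 4 frequency bins
rng = np.random.default_rng(0)
dtf_demo = rng.random((3, 3, 4))
mask = np.array([False, True, True, False])      # bins that fall inside the band

adjacency = dtf_to_adjacency(dtf_demo, mask, threshold=0.5)
in_degree = adjacency.sum(axis=1)    # edges arriving at each channel
out_degree = adjacency.sum(axis=0)   # edges leaving each channel
print(adjacency)
print("in-degree:", in_degree, "out-degree:", out_degree)
```

Node degree is only one of the simplest graph-theory indices; the same adjacency matrix could feed other graph metrics.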