### Overview - Graphing brain connectivity in schizophrenia from EEG data - Create DTF Graphs and Gold Table

EEG analysis was carried out on three versions of the data:
1. the raw EEG data,
2. the data re-referenced with the Average Reference Method, and
3. the data re-referenced with the Zero Reference Method (REST).

This allowed us to explore how the choice of reference electrode impacts connectivity outcomes.

EEG data were analyzed using three connectivity methods: Phase-Locking Value (PLV), Phase-Lag Index (PLI), and Directed Transfer Function (DTF), and statistical indices based on graph theory. 

##### In this notebook we will:
  * Perform graph analysis of EEG connectivity. The overall analysis uses three connectivity measures:
    * Directed Transfer Function (DTF)
    * Phase-Locking Value (PLV)
    * Phase-Lag Index (PLI)
##### This Notebook will use Directed Transfer Function (DTF)

###### We need to convert the REST Reference Method PySpark DataFrame to a pandas DataFrame so we can use it with the SciPy, MNE, and NumPy packages

In [0]:
# Pull each band's Butterworth-filtered, REST-referenced gold table into pandas.
band_names = ["delta", "theta", "alpha", "beta", "gamma"]

# band name -> pandas DataFrame holding that band's table
df_bands_rest = dict()
for band in band_names:
    # The query orders rows by `time` (ascending) before conversion to pandas.
    rest_query = f"SELECT * FROM main.solution_accelerator.butter_rest_{band}_gold ORDER BY time ASC"
    df_bands_rest[band] = spark.sql(rest_query).toPandas()

    # Report the distinct patients found in this band's table.
    patient_ids = df_bands_rest[band]["patient_id"].unique()
    print(f"patient_ids:::{patient_ids}")
# display(df_bands_rest[band])

###### We need to convert the Average Reference Method PySpark DataFrame to a pandas DataFrame so we can use it with the SciPy, MNE, and NumPy packages

In [0]:
# Pull each band's Butterworth-filtered, average-referenced gold table into pandas.
band_names = ["delta", "theta", "alpha", "beta", "gamma"]

# band name -> pandas DataFrame holding that band's table
df_bands_avg = dict()

for band in band_names:
    # The query orders rows by `time` (ascending) before conversion to pandas.
    avg_query = f"SELECT * FROM main.solution_accelerator.butter_avg_{band}_gold ORDER BY time ASC"
    df_bands_avg[band] = spark.sql(avg_query).toPandas()

    # Report the distinct patients found in this band's table, then show the table.
    patient_ids = df_bands_avg[band]["patient_id"].unique()
    print(f"patient_ids:::{patient_ids}")
    display(df_bands_avg[band])


##### Graph analysis of EEG data measuring connectivity using Directed Transfer Function (DTF) connectivity measure

##### Directed Transfer Function (DTF)
Directed Transfer Function (DTF) is a frequency-domain measure derived from multivariate autoregressive (MVAR) modeling of EEG signals. It estimates the directed influence or connectivity between different brain regions in the frequency domain.

In [0]:
# Install MNE-Python; later cells use it to build Raw objects and attach the electrode montage.
# NOTE(review): a later cell also runs `import scot`, which is not installed here -
# confirm it is already available on the cluster, otherwise that cell will fail on import.
%pip install mne

In [0]:
import mne

band_names = ["delta", "theta", "alpha", "beta", "gamma"]

# Sampling rate in Hz
sfreq = 250

# band name -> {patient_id -> MNE Raw object with the standard 10-20 montage applied}
mne_raw_all = {}

# Build one MNE Raw object per (band, patient) from the REST-referenced tables
for band in band_names:
    mne_raw_all[band] = {}

    # Every column other than the bookkeeping ones is treated as an EEG channel
    ch_names = [c for c in df_bands_rest[band].columns if c not in ('patient_id', 'time')]
    # print(f"ch_names:::{ch_names}")

    # Patients present in this band's table, in order of first appearance
    pt_names = list(df_bands_rest[band]['patient_id'].unique())
    # print(f"patient_names:::{pt_names}")

    for pt in pt_names:
        print("PATIENT_ID::", pt)

        # Keep only this patient's rows, then strip the non-channel columns
        is_pt_row = df_bands_rest[band]['patient_id'] == pt
        df_pt_data = df_bands_rest[band].loc[is_pt_row].drop(columns=['patient_id', 'time'])
        # print("LEN::", len(df_pt_data.index))

        # Rows are samples and columns are channels at this point
        np_pt_data = df_pt_data.to_numpy()

        # MNE needs an Info structure describing the channels and the data
        # transposed so that channels come first
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
        mne_raw_pt = mne.io.RawArray(np_pt_data.T, info)

        # The mne raw data object gives us time, access it as `data, times = raw[:]`
        # Channel mapping: attach the standard 10-20 montage and store the result
        mne_raw_all[band][pt] = mne_raw_pt.set_montage('standard_1020')

        # # Plot the data so we can compare graphs to reference methods later
        # print(f"Patient ID: {pt}")
        # mne_raw_all[band][pt] .plot(scalings=dict(eeg=50), title=(f"Patient ID: {pt}"), start=150, duration=100)
        # print(f"Patient ID: {pt}")
        # mne_raw_all[band][pt].plot_sensors(ch_type="eeg", title=(f"Patient ID: {pt}"))
        # print(f"Patient ID: {pt}")
        # spectrum = mne_raw_all[band][pt].compute_psd().plot(average=True, picks="data", exclude="bads", amplitude=False)

# mne_raw_all now holds the MNE Raw objects used by the connectivity analysis below


In [0]:
import numpy as np
import mne
from scipy.linalg import inv
import scot

# EEG frequency bands processed below; each name is a key of mne_raw_all / df_bands_rest.
band_names = ["delta", "theta", "alpha", "beta", "gamma"]

def compute_dtf(data, n_channels, n_times, order=5):
    """
    Compute the normalized Directed Transfer Function (DTF) from multichannel time series data.

    A multivariate autoregressive (MVAR) model

        x(t) = sum_{k=1..order} A_k @ x(t - k) + e(t)

    is fitted by ordinary least squares. The transfer matrix
    H(f) = inv(I - sum_k A_k * exp(-1j * w_f * k)) is then evaluated at
    w_f = 2 * pi * f / n_times for f = 0 .. n_times // 2, and each row of
    |H(f)| is scaled to unit Euclidean norm.

    Parameters:
    data : ndarray, or object exposing ``get_data()``
        Multichannel time series data of shape (n_channels, n_times). Objects
        with a ``get_data()`` method (e.g. an MNE Raw) are converted with it,
        because indexing such containers does not return a plain array.
    n_channels : int
        Number of channels (rows of ``data``).
    n_times : int
        Number of samples to use; only the first ``n_times`` columns of
        ``data`` are read. Also fixes the frequency grid.
    order : int
        Order of the MVAR model.

    Returns:
    dtf : ndarray
        DTF matrix of shape (n_channels, n_channels, n_frequencies) with
        n_frequencies = n_times // 2 + 1. ``dtf[i, j, f]`` is
        |H_ij(f)| / sqrt(sum_m |H_im(f)|^2), i.e. the influence of channel j
        on channel i.

    Raises:
    ValueError
        If ``data`` does not match (n_channels, >= n_times) or if there are
        not more samples than ``order``.
    numpy.linalg.LinAlgError
        If I - A(f) is singular at some frequency.
    """
    if hasattr(data, "get_data"):
        data = data.get_data()
    data = np.asarray(data, dtype=float)

    if data.ndim != 2 or data.shape[0] != n_channels or data.shape[1] < n_times:
        raise ValueError(
            f"data must have shape ({n_channels}, >={n_times}), got {data.shape}"
        )
    if order < 1 or n_times <= order:
        raise ValueError(
            f"need order >= 1 and n_times > order, got order={order}, n_times={n_times}"
        )

    n_frequencies = n_times // 2 + 1  # Number of positive frequencies

    # Fit MVAR model: Y = X @ B, with X laid out lag-major, i.e. the column for
    # (lag k, channel j) sits at index (k - 1) * n_channels + j.
    X = np.hstack([data[:, order - k:n_times - k].T for k in range(1, order + 1)])
    Y = data[:, order:n_times].T
    B = np.linalg.lstsq(X, Y, rcond=None)[0]  # (order * n_channels, n_channels)

    # B[(k - 1) * n_channels + j, i] is the weight of channel j at lag k on
    # channel i, so the reshape must be (lag, source, target) before moving
    # the axes to A[target, source, lag]. Reshaping straight to
    # (n_channels, n_channels, order) would mix up source channel and lag.
    A = B.reshape(order, n_channels, n_channels).transpose(2, 1, 0)

    # Transfer matrix H(f) for every frequency at once: shape (f, i, j).
    omega = 2 * np.pi * np.arange(n_frequencies) / n_times
    phase = np.exp(-1j * np.outer(omega, np.arange(1, order + 1)))  # (f, lag)
    Af = np.einsum("fk,ijk->fij", phase, A)
    H = np.linalg.inv(np.eye(n_channels) - Af)

    # Normalize each row i by the total inflow into channel i.
    magnitude = np.abs(H)
    row_norm = np.sqrt(np.sum(magnitude ** 2, axis=2, keepdims=True))
    dtf = magnitude / row_norm

    # Callers expect frequency as the last axis.
    return np.ascontiguousarray(np.moveaxis(dtf, 0, -1))

def calculate_dtf(data_intervals, steps, channels, sample_rate, bands, flag):
    """
    Build per-interval, per-band DTF connectivity matrices with a SCoT workspace.

    Parameters:
    data_intervals : sequence
        Indexable collection of signals; interval ``k`` uses the entries
        ``k * channels`` .. ``(k + 1) * channels - 1``.
    steps : sized
        Only its length is used: ``len(steps) - flag`` intervals are processed.
    channels : int
        Number of channels per interval (also used as ``channels - 5`` for the
        MVAR model order of the workspace).
    sample_rate : number
        Sampling rate passed to SCoT as ``fs``; ``int(sample_rate / 2)`` is
        used as ``nfft`` and as the length of the frequency index array.
    bands : sequence of 5 truthy/falsy flags
        Selects which of (delta, theta, alpha, beta, gamma) are stored.
    flag : int
        Subtracted from ``len(steps)`` to obtain the number of intervals.

    Returns:
    matrix : ndarray
        Array of shape (intervals * sum(bands), channels, channels); the
        selected bands of interval ``k`` occupy consecutive slices starting at
        ``k * sum(bands)``. Each entry is the mean of the band's values for
        that channel pair, or 0 when the band is empty.

    NOTE(review): ``frequency_bands`` is not defined in the visible part of the
    notebook; it is assumed to be a callable returning the five per-band value
    arrays for a frequency axis and one DTF spectrum - confirm before use.
    """
    n_selected = sum(bands)
    n_intervals = len(steps) - flag
    matrix = np.zeros(shape=(n_intervals * n_selected, channels, channels))

    ws = scot.Workspace({'model_order': channels - 5}, reducedim='no_pca', nfft=int(sample_rate / 2), fs=sample_rate)

    f = np.arange(0, int(sample_rate / 2))

    for k in range(n_intervals):
        # Each interval owns a contiguous run of `channels` entries.
        first = k * channels
        last = first + channels
        segment = [data_intervals[h] for h in range(first, last)]

        ws.set_data(segment)
        ws.do_mvarica()
        ws.fit_var()
        results = ws.get_connectivity('DTF')

        base_row = k * n_selected
        for x in range(channels):
            for y in range(channels):
                delta, theta, alpha, beta, gamma = frequency_bands(f, results[x][y])
                slot = 0  # position of the current band among the selected ones
                for z, values in enumerate((delta, theta, alpha, beta, gamma)):
                    if not bands[z]:
                        continue
                    matrix[base_row + slot][x, y] = values.mean() if len(values) != 0 else 0
                    slot += 1
    return matrix

# NOTE: pt_all is not referenced by the loop below, which processes every
# patient present in mne_raw_all.
pt_all = ['s11', 'h13']

# DTF results, keyed like mne_raw_all: dtf_matrix_dict[band][patient_id] -> ndarray
# of shape (n_channels, n_channels, n_frequencies).
dtf_matrix_dict = {}
for band in band_names:
    # Get channel names: every column except the bookkeeping ones.
    # Depends only on the band, so it is computed once per band.
    ch_names = [c for c in df_bands_rest[band].columns if c not in ['patient_id', 'time']]
    # print(f"ch_names:::{ch_names}")

    dtf_matrix_dict[band] = {}
    for pt, raw in mne_raw_all[band].items():
        # Epoching / averaging was prototyped here but is not applied: the DTF
        # is computed on the whole continuous recording of each patient.

        # compute_dtf indexes its input like a 2-D array. Indexing an MNE Raw
        # returns a (data, times) tuple instead, so hand over the plain
        # (n_channels, n_times) array.
        pt_data = raw.get_data()
        n_channels, n_times = pt_data.shape

        # Compute DTF and file it under (band, patient) so results can be
        # looked up by name rather than by Raw object identity.
        dtf_matrix_dict[band][pt] = compute_dtf(data=pt_data, n_channels=n_channels, n_times=n_times)
        # print("DTF matrix shape:", dtf_matrix_dict[band][pt].shape)
       
        # # Plot the DTF matrix at a specific frequency (e.g., frequency index 10)
        # import matplotlib.pyplot as plt
        # plt.imshow(dtf_matrix[:, :, 10], aspect='auto', origin='lower')
        # plt.title('Directed Transfer Function (DTF) at frequency index 10')
        # plt.xlabel('Channel')
        # plt.ylabel('Channel')
        # plt.colorbar(label='DTF value')
        # plt.show()

        # # Plot the DTF matrix for a specific frequency band (e.g., alpha band 8-12 Hz)
        # import matplotlib.pyplot as plt
        # alpha_band = (8, 12)
        # alpha_idx = np.where((fmin <= alpha_band[0]) & (fmax >= alpha_band[1]))[0]
        # plt.imshow(np.mean(dtf_matrix[:, :, alpha_idx], axis=-1), aspect='auto', origin='lower')
        # plt.title('Directed Transfer Function (DTF) - Alpha Band')
        # plt.xlabel('Channel')
        # plt.ylabel('Channel')
        # plt.colorbar(label='DTF value')
        # plt.show()


In [0]:
# import mne
# import numpy as np
# from scipy.signal import coherence, csd
# from numpy.linalg import inv

# def compute_dtf(data, sfreq, freq_band, n_channels):
#     """
#     Compute Directed Transfer Function (DTF) for EEG data.
    
#     Parameters:
#     data : np.ndarray
#         Multichannel time series data (channels x time points).
#     sfreq : float
#         Sampling frequency of the data.
#     freq_band : tuple
#         Frequency band for DTF calculation (e.g., (8, 12) for alpha band).
#     n_channels : int
#         Number of channels.
        
#     Returns:
#     dtf : np.ndarray
#         Directed Transfer Function matrix (channels x channels x frequencies).
#     """
#     n_times = data.shape[1]
#     freqs = np.fft.rfftfreq(n_times, 1/sfreq)
#     freq_mask = (freqs >= freq_band[0]) & (freqs <= freq_band[1])

#     # Compute Cross-Spectral Density (CSD) matrix
#     csd_matrix = np.zeros((n_channels, n_channels, len(freqs)), dtype=complex)
#     for i in range(n_channels):
#         for j in range(n_channels):
#             f, Pxy = csd(data[i], data[j], sfreq, nperseg=n_times)
#             csd_matrix[i, j, :] = Pxy

#     # Compute DTF
#     dtf = np.zeros((n_channels, n_channels, len(freqs)))
#     for f_idx, freq in enumerate(freqs):
#         if freq_mask[f_idx]:
#             H = np.zeros((n_channels, n_channels), dtype=complex)
#             for i in range(n_channels):
#                 for j in range(n_channels):
#                     H[i, j] = csd_matrix[i, j, f_idx]

#             H_inv = inv(H)
#             for i in range(n_channels):
#                 for j in range(n_channels):
#                     dtf[i, j, f_idx] = np.abs(H_inv[i, j])**2 / np.sum(np.abs(H_inv[i, :])**2)
#     return dtf

# band_names = ["delta", "theta", "alpha", "beta", "gamma"]

# # Define the sampling frequency (in Hz)
# sampling_freq = 250 

# # Define frequency bands
# frequency_bands = {
#     'delta': (2, 4),
#     'theta': (4.5, 7.5),
#     'alpha': (8, 12.5),
#     'beta': (13, 30),
#     'gamma': (30, 45)
# }

# times: []
# channel_names: []

# # Transpose data for each band and store in a dictionary
# transposed_data = {}
# for band in band_names:
#     # Get the times
#     times = df_bands_rest[band]['time'].values.tolist()

#     # Get channel names as a list. Drop columns that are not electrode channels
#     columns_to_drop = ['patient_id', 'time']
#     channel_names = df_bands_rest[band].drop(columns=columns_to_drop).columns.tolist()
#     print(f"channel_names:::{channel_names}")

#     # Extract data for the current band and transpose
#     rest_data_T= df_bands_rest[band][channel_names].values.T  # Transpose to get shape (n_channels, n_times)

#     # Store transposed data in the dictionary
#     transposed_data[band] = rest_data_T

#     print(f"LEN of channels:::{len(channel_names)}")
#     print(f"list(frequency_bands.values()):::{list(frequency_bands.values())}")

#     # Compute DTF, `data` with shape (channels x time points)
#     for freq_range in list(frequency_bands.values()):
#         dtf = compute_dtf(data=transposed_data[band], sfreq=sampling_freq, freq_band=freq_range, n_channels=len(channel_names))
#         print(dtf[-1])

In [0]:
# import numpy as np
# import pandas as pd
# from scipy.signal import hilbert

# # Function to compute analytic signal using Hilbert transform
# def analytic_signal(sig):
#     return hilbert(sig)

# # Function to compute covariance matrix
# def covariance_matrix(H):
#     return np.cov(H)

# # Function to compute Directed Transfer Function (DTF)
# def dtf(H):
#     C = covariance_matrix(H)
#     Pxx = np.diag(C)
#     D = np.dot(np.dot(H, np.linalg.inv(C)), H.T)
#     D /= Pxx[:, None]
#     D = np.abs(D)
#     return D

# def compute_dtf_matrix(pd_df):
#     """
#     Compute the Directed Transfer Function (DTF) adjacency matrix for multiple EEG channels.

#     Parameters:
#     pd_df (pandas DataFrame): DataFrame containing EEG signals as columns.

#     Returns:
#     numpy array: Adjacency matrix where entry (i, j) represents the DTF from channel j to channel i.
#     """

#     # Prepare a 2D array where each column is a signal
#     signals = pd_df.values
#     display(signals)
#     H = analytic_signal(signals)  # Apply Hilbert transform to the entire set of signals
#     print("H::::")
#     # display(H)

#     # # Compute DTF for the matrix of signals
#     dtf_matrix = dtf(H)

#     # return dtf_matrix

# band_names = ["delta", "theta", "alpha", "beta", "gamma"]
# for band in band_names:
#     # Get channel names as a list. Drop columns that are not electrode channels
#     columns_to_drop = ['patient_id', 'time']
#     patients = df_bands_rest[band]['patient_id'].unique().tolist()
#     print(patients);
#     for pt in patients:
#         df_pt = df_bands_rest[band].loc[df_bands_rest[band]['patient_id'] == pt].drop(columns=columns_to_drop).head(1000)
#         channel_names = list(df_pt.head())
#         print(f"channel_names:::{channel_names}")
#         num_channels = len(channel_names)
#         display(df_pt)
#         # Compute DTF for all pairs of channels
#         dtf_matrix = np.zeros((num_channels, num_channels))

#         compute_dtf_matrix(df_pt)