### Overview - Graphing brain connectivity in schizophrenia from EEG data - Create PLI Graphs and Gold Table

EEG analysis was carried out using:
1. the raw EEG data,
2. the data re-referenced with the Average Reference Method, and
3. the data re-referenced with the Zero Reference Method.

This allowed us to explore how the choice of reference electrode impacts connectivity outcomes.

EEG data were analyzed using three connectivity methods: Phase-Locking Value (PLV), Phase-Lag Index (PLI), and Directed Transfer Function (DTF), and statistical indices based on graph theory. 

##### In this notebook we will:
  * Perform graph analysis of EEG data, measuring connectivity with one of three connectivity measures:
    * Directed Transfer Function (DTF)
    * Phase-Locking Value (PLV)
    * Phase-Lag Index (PLI)
##### This Notebook will use Phase-Lag Index (PLI)

In [0]:
# Load the Butterworth-filtered REST gold tables, one Spark query result per
# frequency band.
# No patient selection happens here: per the original note, if the gold tables
# only include the patients of interest there is no need to select
# (formerly: pt_to_display = ['s11', 'h11']).
band_names = ["delta","theta", "alpha", "beta", "gamma"]

# Map band name -> DataFrame returned by spark.sql (rows ordered by time).
df_bands_rest = {}
for band in band_names:
    band_query = f"SELECT * FROM main.solution_accelerator.butter_rest_{band}_gold ORDER BY time ASC"
    df_bands_rest[band] = spark.sql(band_query)

display(df_bands_rest["delta"])

# `band` still holds the last entry of band_names after the loop, so the
# patient IDs shown here come from that last band's table, not from "delta".
distinct_patients = df_bands_rest[band].select("patient_id").distinct()
unique_patient_ids = distinct_patients.toPandas()['patient_id'].values
display(unique_patient_ids)

In [0]:
# Install networkx, which a later cell imports as `nx` to build and draw the graphs.
!pip install networkx

In [0]:
from scipy.signal import hilbert
import numpy as np
import pandas as pd
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

# define a couple functions we will need
def instantaneousPhase(sig):
    """Return the instantaneous phase of ``sig`` in radians.

    The analytic signal is computed with ``scipy.signal.hilbert`` and its
    angle is taken sample by sample with ``numpy.angle``.
    """
    analytic_signal = hilbert(sig)
    phase = np.angle(analytic_signal)
    return phase

def pli(sig1, sig2):
    """Return the Phase-Lag Index (PLI) between two signals.

    The instantaneous phases of both signals are subtracted, the difference
    is wrapped into [-pi, pi), and the result is the absolute value of the
    mean sign of that wrapped difference.
    """
    phase_delta = instantaneousPhase(sig1) - instantaneousPhase(sig2)
    # Wrap the difference into [-pi, pi) before taking its sign.
    wrapped_delta = np.mod(phase_delta + np.pi, 2*np.pi) - np.pi
    return abs(np.mean(np.sign(wrapped_delta)))

def adj_matrix(pd_df):
    """Build the pairwise PLI adjacency matrix for one DataFrame.

    The columns selected by ``pd_df.columns[1:20]`` are treated as channels
    (column 0 is skipped). Entry ``[i, j]`` holds ``pli`` of channel ``i``
    against channel ``j``; every ordered pair is evaluated, including the
    diagonal.

    Returns:
        A square NumPy array with one row and one column per channel.
    """
    channel_names = pd_df.columns[1:20]
    n_channels = len(channel_names)
    connectivity = np.zeros((n_channels, n_channels))

    for row, row_name in enumerate(channel_names):
        for col, col_name in enumerate(channel_names):
            row_signal = pd_df[row_name].values
            col_signal = pd_df[col_name].values
            connectivity[row, col] = pli(row_signal, col_signal)

    return connectivity
    


# Map band name -> list of adjacency matrices, one per patient, appended in
# the order the patient IDs come back from the distinct() query.
# NOTE(review): no ordering is applied to the patient IDs, so a given list
# position may not always be the same patient — confirm before indexing
# these lists positionally.
adjMatrices = {}

for band in band_names:
    df = df_bands_rest[band]
    unique_patient_ids = df.select("patient_id").distinct().toPandas()['patient_id'].values
    display(unique_patient_ids)

    band_matrices = []
    adjMatrices[band] = band_matrices
    for ptid in unique_patient_ids:
        display(ptid)
        # Pull this patient's rows to the driver as a pandas DataFrame.
        pd_df = df.filter(df.patient_id == ptid).toPandas()
        adjMatrix = adj_matrix(pd_df)
        band_matrices.append(adjMatrix)

In [0]:
import matplotlib.pyplot as plt
def plot_heatmap(adjMatrix, ptid, bandname):
    """Show ``adjMatrix`` as a heatmap titled with the patient ID and band.

    Args:
        adjMatrix: 2-D matrix handed to ``plt.matshow``.
        ptid: Patient identifier shown in the title.
        bandname: Frequency-band name shown in the title.
    """
    heading = f"PtID: {ptid}, Band: {bandname}"

    # Draw the matrix and attach a colorbar as the value scale.
    image = plt.matshow(adjMatrix, cmap='Spectral_r')
    plt.colorbar(image, fraction=0.046, pad=0.04)

    # Hide the tick numbers on both axes.
    plt.yticks([])
    plt.xticks([])

    plt.title(heading)
    plt.show()


# For every band, plot heatmaps for the first two matrices in the list.
# The patient labels are hard-coded and tied to list position.
# NOTE(review): confirm these labels match the order in which the matrices
# were appended to adjMatrices.
for band in band_names:
    adjMatrixList = adjMatrices[band]
    for position, patient_label in enumerate(('S11', "H13")):
        plot_heatmap(adjMatrixList[position], patient_label, band)

In [0]:
import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(adjMatrix, ptid, bandname):
    """Draw the strongest connections in ``adjMatrix`` on a fixed electrode layout.

    An undirected edge joins nodes ``i`` and ``j`` (``i < j``, so self
    connections are skipped) when ``adjMatrix[i, j]`` is greater than the
    mean of the whole matrix plus two standard deviations.

    Args:
        adjMatrix: Square NumPy array of connectivity weights; only the
            upper triangle is inspected for edges.
        ptid: Patient identifier shown in the title.
        bandname: Frequency-band name shown in the title.
    """
    # Fixed 2-D drawing coordinates per electrode label. The dict order also
    # decides which matrix index maps to which label.
    # NOTE(review): assumes this order matches the row/column order of
    # adjMatrix — confirm against the channel columns used to build it.
    electrode_xy = {
        'Fp1': (-3.15, 6.85),
        'F7': (-8.10, 4.17),
        'F3': (-4.05, 3.83),
        'Fz': (0, 3.6),
        'F4': (4.05, 3.83),
        'F8': (8.10, 4.17),
        'T3': (-10.1,0),
        'C3': (-5,0),
        'Cz': (0,0),
        'C4': (5,0),
        'T4': (10.1,0),
        'T5': (-8.10, -4.17),
        'P3': (-4.05, -3.83),
        'Pz': (0, -3.6),
        'P4': (4.05, -3.83),
        'T6': (8.10, -4.17),
        'O1': (-3.15, -6.85),
        'O2': (3.15, -6.85),
        'Fp2': (3.15, 6.85),
    }
    labels = list(electrode_xy)

    # One node per electrode, carrying its coordinates as the "pos" attribute.
    graph = nx.Graph()
    for label, xy in electrode_xy.items():
        graph.add_node(label, pos=xy)

    node_count = adjMatrix.shape[0]

    # Threshold statistics are taken over every entry of the matrix
    # (diagonal and both triangles included).
    threshold = np.mean(adjMatrix) + 2*np.std(adjMatrix)

    # Visit each unordered pair once by walking the strict upper triangle.
    upper_rows, upper_cols = np.triu_indices(node_count, k=1)
    for i, j in zip(upper_rows, upper_cols):
        if adjMatrix[i, j] > threshold:
            graph.add_edge(labels[i], labels[j])

    # Draw the graph at the fixed electrode positions.
    nx.draw(graph, electrode_xy, with_labels=True, node_size=700, node_color="skyblue", font_size=10)
    plt.title(f"PtID: {ptid}, Band: {bandname}")
    plt.show()

# For every band, draw connectivity graphs for the first two matrices in the
# list. The patient labels are hard-coded and tied to list position.
# NOTE(review): confirm these labels match the order in which the matrices
# were appended to adjMatrices.
for band in band_names:
    adjMatrixList = adjMatrices[band]
    for position, patient_label in enumerate(('S11', "H13")):
        plot_graph(adjMatrixList[position], patient_label, band)
    # Unfinished stub kept from the original cell:
    #for matrix in adjMatrixList:
        #plot_graph(matrix, )