## AIM: Train Graph Convolutional Networks (GCNs) as multiclass classification models for predicting the psychiatric diagnosis based on EEG features

GraphLambda with & without edge attributes for predicting psychiatric diagnosis with statistical features of the power per frequency band per channel group (frontal, central, parietal, occipital) and connectivity features per frequency band per channel group (l/m/r; frontal, central, posterior). Statistical features are calculated with EC data, with EO data, and with the ratio of EC divided by EO.

In [1]:
import numpy as np
import pandas as pd
import sklearn
import pickle
import mne
import os
import matplotlib.pyplot as plt
from mne.time_frequency import tfr_multitaper
import networkx as nx

%matplotlib inline

# Limit MNE's log output to messages of level WARNING and above so that
# the cell outputs are not flooded with log lines.
mne.set_log_level('WARNING')

## Loading in feature data

#### stat features + connectivity features

In [3]:
# Load the pickled feature table and keep only the rows that have a
# diagnosis label; then show 7 random rows as a quick sanity check.
df_stat_conn_features = (
    pd.read_pickle(r'D:\Documents\RU\Master_Neurobiology\Internship_jaar_2\Project\TD-BRAIN\TD-BRAIN_extracted_features\df_selected_stat_conn_features.pkl')
    .dropna(subset=['diagnosis'])
)
df_stat_conn_features.sample(7)

Unnamed: 0,ID,epoch,diagnosis,EC_central_delta_std,EC_central_delta_mean,EC_central_delta_median,EC_delta_l_frontal-m_frontal,EC_delta_l_frontal-r_frontal,EC_delta_l_frontal-l_central,EC_delta_l_frontal-m_central,...,ratio_gamma_m_central-r_central,ratio_gamma_m_central-l_posterior,ratio_gamma_m_central-m_posterior,ratio_gamma_m_central-r_posterior,ratio_gamma_r_central-l_posterior,ratio_gamma_r_central-m_posterior,ratio_gamma_r_central-r_posterior,ratio_gamma_l_posterior-m_posterior,ratio_gamma_l_posterior-r_posterior,ratio_gamma_m_posterior-r_posterior
1407,sub-88044141,4,ADHD,0.412439,-0.168952,-0.15843,0.79975,0.727261,0.812249,0.789707,...,0.924987,0.866749,0.896587,0.886406,0.840558,0.867532,0.887689,0.885101,0.856176,0.901534
2056,sub-88059977,5,OCD,0.388174,-0.150207,-0.144225,0.800401,0.805509,0.732293,0.740633,...,1.003223,0.966227,0.987201,0.978696,0.97194,0.989077,0.97864,0.990548,0.986468,0.989931
1386,sub-88042969,7,MDD,0.391372,-0.156194,-0.155787,0.893121,0.761904,0.885288,0.844061,...,1.037316,1.005055,1.005988,1.011161,1.066669,1.052438,1.007116,0.994886,1.02106,1.016479
1747,sub-88053497,8,OCD,0.375859,-0.140785,-0.131386,0.841164,0.778846,0.749598,0.729198,...,0.978731,0.950695,0.971756,0.961029,0.912759,0.938486,0.959458,0.967874,0.926915,0.955225
1210,sub-88029425,11,ADHD,0.346542,-0.121467,-0.115682,0.86361,0.79058,0.802815,0.807824,...,1.010778,1.020124,1.021165,1.013914,1.021832,1.024234,1.020812,1.010803,1.007805,1.014421
1467,sub-88046437,4,MDD,0.388512,-0.15714,-0.159513,0.84943,0.812235,0.782721,0.771285,...,1.009528,0.986628,0.992984,0.978494,0.989447,0.990907,0.989985,0.990729,0.997684,0.991289
2682,sub-88076717,7,OCD,0.386303,-0.152206,-0.15984,0.845244,0.758196,0.823563,0.732196,...,1.051508,1.074832,1.051253,1.067357,1.11182,1.083585,1.088771,1.046488,1.061953,1.052174


In [4]:
# Number of rows per diagnosis label (check of the class balance).
df_stat_conn_features['diagnosis'].value_counts()

diagnosis
SMC        540
HEALTHY    540
MDD        540
ADHD       540
OCD        540
Name: count, dtype: int64

In [5]:
# Split the feature table into three condition-specific sets (EC, EO, ratio).
# Each set keeps every column that does NOT start with one of the other two
# condition prefixes, so the shared non-prefixed columns (e.g. 'ID', 'epoch',
# 'diagnosis') are present in all three sets.
df_stat_conn_features_ec = df_stat_conn_features[
    df_stat_conn_features.columns[
        ~df_stat_conn_features.columns.str.startswith('EO')
        & ~df_stat_conn_features.columns.str.startswith('ratio')
    ]
]
df_stat_conn_features_eo = df_stat_conn_features[
    df_stat_conn_features.columns[
        ~df_stat_conn_features.columns.str.startswith('EC')
        & ~df_stat_conn_features.columns.str.startswith('ratio')
    ]
]
df_stat_conn_features_ratio = df_stat_conn_features[
    df_stat_conn_features.columns[
        ~df_stat_conn_features.columns.str.startswith('EC')
        & ~df_stat_conn_features.columns.str.startswith('EO')
    ]
]

In [6]:
# Report (rows, columns) of the ratio, EC and EO feature sets, one per line.
print(
    df_stat_conn_features_ratio.shape,
    df_stat_conn_features_ec.shape,
    df_stat_conn_features_eo.shape,
    sep='\n',
)

(2700, 178)
(2700, 241)
(2700, 248)


In [83]:
# obtain graphs from selected synchrony features

def get_graph_from_features(df, subject_id):
    """
    Build one graph per frequency band from the selected synchrony features
    of a single subject.

    Only the row(s) of ``subject_id`` with ``epoch == 1`` are used (for now),
    and of those only the first one. Every column whose name contains one of
    the keywords 'ID', 'epoch', 'diagnosis', 'mean', 'median', 'std', 'skew'
    or 'kurt' is dropped. The remaining columns are expected to be named
    ``[EC_|EO_|ratio_]<band>_<ch_group1>-<ch_group2>``; each of them becomes
    one edge between ``ch_group1`` and ``ch_group2`` whose value is stored as
    the edge attribute 'phase_diff'.

    The per-band diagnostics (band name, edge_index shape, edge_attr and its
    shape) are printed while the graphs are built.

    :param df: the dataframe, with an 'ID' and an 'epoch' column plus the
        feature columns
    :param subject_id: value of the 'ID' column that selects the subject
    :return: dict mapping each band name ('delta', 'theta', 'alpha', 'beta',
        'gamma') to a tuple ``(edge_index, edge_attr)``; ``edge_index`` is an
        array of shape (2, num_edges) holding the channel-group names of both
        edge ends, ``edge_attr`` is an array of shape (num_edges,) holding the
        matching feature values
    :raises ValueError: if ``df`` has no row for ``subject_id`` with
        ``epoch == 1``
    """
    bands = ['delta', 'theta', 'alpha', 'beta', 'gamma']

    # get the features
    df = df[(df['ID'] == subject_id) & (df['epoch'] == 1)]  # for now select just 1 epoch
    if df.empty:
        # Fail with a clear message instead of an IndexError deep inside the band loop.
        raise ValueError(f'no rows found for subject {subject_id!r} with epoch == 1')

    excluded_columns = ['ID', 'epoch', 'diagnosis', 'mean', 'median', 'std', 'skew', 'kurt']  # exclude demographic and statistical features
    pattern = '|'.join(excluded_columns)  # create a pattern string
    # Work on a copy so that relabelling the columns never touches the caller's dataframe.
    features = df[df.columns[~df.columns.str.contains(pattern)]].copy()
    features.columns = features.columns.str.replace('EC_', '').str.replace('EO_', '').str.replace('ratio_', '').tolist()  # remove prefix

    graphs = {}
    for band in bands:
        print('\n')
        print(band)
        features_band = features[features.columns[features.columns.str.contains(band)]]

        # One edge per remaining column: '<band>_<ch_group1>-<ch_group2>'.
        rows_list = []
        for column, value in features_band.items():
            ch_group1, ch_group2 = column.replace(band + '_', '').split('-')  # remove band prefix
            rows_list.append({'phase_diff': value.values[0], 'ch_group1': ch_group1, 'ch_group2': ch_group2})

        if rows_list:
            syncro_matrix_df = pd.DataFrame(rows_list)
            syncro_graph = nx.from_pandas_edgelist(syncro_matrix_df, source='ch_group1', target='ch_group2', edge_attr=['phase_diff'])
            edge_list = nx.to_pandas_edgelist(syncro_graph)
            # NOTE: edge_index can differ between bands, because a band only gets
            # the edges for which a feature column exists.
            edge_index = edge_list.iloc[:, 0:2].values.T  # shape (2, num_edges)
            edge_attr = edge_list.iloc[:, 2].values
        else:
            # No connectivity column for this band: return an empty graph
            # rather than failing inside networkx on an empty dataframe.
            edge_index = np.empty((2, 0), dtype=object)
            edge_attr = np.empty(0, dtype=float)

        print(f'{edge_index.shape = }')
        print(f'{edge_attr = }')
        print(f'{edge_attr.shape = }')

        graphs[band] = (edge_index, edge_attr)

    return graphs

# features = get_graph_from_features(df_stat_conn_features_ratio, 'sub-88044141')
# print(features)



delta
edge_index.shape = (2, 8)
edge_attr = array([1.02057632, 1.0078158 , 0.97134215, 0.98323892, 0.98620059,
       0.99926581, 0.99666485, 1.05859544])
edge_attr.shape = (8,)


theta
edge_index.shape = (2, 31)
edge_attr = array([1.00735298, 1.05178027, 0.98222608, 0.99164144, 1.03627727,
       1.04593814, 1.01076757, 0.98858614, 1.0343399 , 1.03515944,
       1.10627604, 1.07613953, 1.06580834, 1.07564848, 1.1454665 ,
       1.12833048, 1.07383915, 1.0022602 , 1.02391903, 1.03804833,
       0.9679595 , 1.03182307, 1.03171298, 1.03968656, 1.03266889,
       1.0845724 , 1.10009132, 1.05330528, 0.98686982, 0.97486388,
       1.00963413])
edge_attr.shape = (31,)


alpha
edge_index.shape = (2, 35)
edge_attr = array([1.13610724, 1.23429072, 1.05817182, 1.12516353, 1.27348992,
       0.94255709, 0.95354573, 0.9468783 , 1.10289816, 1.16894478,
       1.07959743, 1.16173502, 1.02904965, 0.96648991, 0.9182226 ,
       1.21280858, 1.1216875 , 1.0699995 , 1.06330886, 0.98712259,
       0.888