In [None]:
import sys
from glob2 import glob

import numpy as np
import pandas as pd # dataframes, tables 
import seaborn as sns # plotting
import matplotlib.pyplot as plt # plotting

import mne
from mne.io import read_raw_edf
import networkx as nx

def read_edf_to_pandas(edf_filename, select_channels=True):
    """Read an EDF recording into a pandas DataFrame, one column per channel.

    Args:
        edf_filename: path of the .edf file, handed to ``read_raw_edf``.
        select_channels: when True (default), channel names are shortened
            with ``name.split(' ')[-1].split('-')[0]`` (for example
            ``'EEG FP1-REF'`` becomes ``'FP1'``) and only the 19 channels
            listed in ``channels_to_use`` are kept, in that order.
            When False, every channel is kept under its original name.

    Returns:
        tuple: ``(df, channel_labels)`` where ``df`` is built from the
        transpose of ``raw.get_data()`` and ``channel_labels`` is the list
        of column names in the order used by ``df``.
    """
    # read edf file
    raw_data = read_raw_edf(edf_filename, verbose=False, preload=False)

    if not select_channels:
        # keep every channel under the name stored in the edf file
        channel_labels = raw_data.info["ch_names"]
        df = pd.DataFrame(raw_data.get_data().T, columns=channel_labels)
        return df[channel_labels], channel_labels

    # The TUEP database has 3 EEG channel configurations
    # ('02_tcp_le', '03_tcp_ar_a', '01_tcp_ar') whose channel counts and
    # names differ. To compare recordings we keep only the 19 channels
    # (short labels) common to all of them, ordered left-to-right and
    # top-down.
    channels_to_use = ['FP1', 'FP2', 'F7', 'F3', 'FZ', 'F4', 'F8',
                       'T3', 'C3', 'CZ', 'C4', 'T4', 'T5',
                       'P3', 'PZ', 'P4', 'T6', 'O1', 'O2']

    # rename the channels of the in-memory edf object to their short labels
    raw_data.rename_channels(
        mapping=lambda ch: ch.split(' ')[-1].split('-')[0])

    # Explicit check instead of `assert` inside a bare `except`: an assert
    # is stripped under `python -O`, and a bare except hides unrelated errors.
    available = set(raw_data.info["ch_names"])
    if not all(ch in available for ch in channels_to_use):
        print('Not all required channels are in the edf file.')

    # pick_channels works in place and returns the same object,
    # so one call is enough for both the data and the names
    picked = raw_data.pick_channels(channels_to_use)
    df = pd.DataFrame(picked.get_data().T,
                      columns=picked.info['ch_names'])

    # return the columns in the order we specified, not the file's order
    channel_labels = channels_to_use
    return df[channel_labels], channel_labels


def compute_corr_matrix(edf_filename):
    """Return the channel-by-channel correlation matrix of one edf file.

    The file is loaded through ``read_edf_to_pandas`` with its default
    arguments (common channels selected and renamed), and the result of
    ``DataFrame.corr()`` on that frame is returned.
    """
    # only the dataframe is needed; the label list is discarded
    eeg_frame, _ = read_edf_to_pandas(edf_filename)
    return eeg_frame.corr()

def proportional_thresholding(matrix_a, percentile):
    """ Performs proportional thresholding of a functional
        connectivity matrix. If a matrix value is lower than
        the threshold it is set to 0, otherwise to 1.

    The threshold is the requested percentile taken over all entries of
    the matrix (diagonal and duplicated symmetric values included).
    The input is left untouched; a thresholded copy is returned.

    Args:
        matrix_a (np.ndarray or pd.DataFrame): weighted functional
            connectivity matrix
        percentile (float): lower cut percentile from 0 to 100 inclusive

    Returns:
        np.ndarray or pd.DataFrame: binary adjacency matrix after
        thresholding, same type and shape as ``matrix_a``
    """
    threshold = np.percentile(matrix_a, q=percentile)

    # use the object's own copy method: a bare `copy(...)` is not
    # imported anywhere in this file and would raise NameError
    matrix_b = matrix_a.copy()

    # both masks are computed from the original so the first
    # assignment cannot influence the second
    matrix_b[matrix_a < threshold] = 0
    matrix_b[matrix_a >= threshold] = 1

    return matrix_b

def plot_correlation_matrix(edf_filename, prop_thresholding=False):
    """ Plots the channel correlation matrix of one edf file as a heatmap.

        Reads the edf file from a relative path (ex.
        ../tuh_eeg_epilepsy/edf/*/*/*/*/*/*.edf) through
        ``compute_corr_matrix``, which selects and renames the common
        channels before correlating them.

    Args:
        edf_filename: path of the .edf file
        prop_thresholding (bool): when True, the matrix is binarized with
            ``proportional_thresholding`` at the 50th percentile before
            plotting
    """

    # calculate the correlation matrix
    corr_matrix = compute_corr_matrix(edf_filename)

    # Tick labels are taken from the matrix itself: no `channel_labels`
    # name exists in this function's scope, so referring to one would
    # raise NameError (or silently use an unrelated notebook global).
    channel_labels = list(corr_matrix.columns)

    # apply proportional thresholding if selected
    if prop_thresholding:
        corr_matrix = proportional_thresholding(corr_matrix,
                                                percentile=50)

    # plot the heatmap for correlation matrix
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    sns.heatmap(corr_matrix,
                xticklabels=channel_labels,
                yticklabels=channel_labels,
                cmap=plt.cm.jet,
                ax=ax)

    ax.set_title('Correlation Matrix')
    ax.set_xlabel('EEG channels')
    ax.set_ylabel('EEG channels')

In [None]:

def entropy(bins, *X):
    """ Histogram-based estimate of the joint entropy H(X, Y, ..., Z) in bits.

    Args:
        bins: bin specification forwarded to ``np.histogramdd``
        *X: one or more equally long 1-D samples, one per variable

    Returns:
        sum(-P * log2(P)) over all histogram cells, where P is the
        fraction of samples falling into each cell
    """
    # bin the samples jointly; the bin edges are not needed
    counts, _edges = np.histogramdd(X, bins=bins)

    # turn counts into cell probabilities
    probabilities = counts.astype(float) / counts.sum()

    # machine epsilon inside the log keeps empty cells from producing
    # log2(0); their contribution is then 0 * log2(eps) = 0
    log_probabilities = np.log2(probabilities + sys.float_info.epsilon)

    return -np.sum(probabilities * log_probabilities)


def mutual_information(bins, X, Y):
    """ Mutual information of two samples from histogram entropies.

    Uses the identity I(X,Y) = H(X) + H(Y) - H(X,Y), with every entropy
    estimated by ``entropy`` using the same ``bins`` argument.
    """
    marginal_entropy_sum = entropy(bins, X) + entropy(bins, Y)
    joint_entropy = entropy(bins, X, Y)

    return marginal_entropy_sum - joint_entropy

# Pairwise mutual information; the bin count follows Sturges' rule
def compute_mi_matrix(df):
    """ Compute Mutual Information matrix.

    Args:
        df (pd.DataFrame): one column per variable, one row per sample

    Return: mi_matrix, a square np.ndarray with one row/column per
        column of ``df``; entry [i, j] is the mutual information between
        columns i and j, and the diagonal holds I(X, X).
    """
    n_rows, n_cols = df.shape
    mi_matrix = np.zeros([n_cols, n_cols])

    # Sturges' rule for number of bins: 1 + log2(n) ~ 1 + 3.322*log10(n)
    n_bins = int(1 + 3.322*np.log10(n_rows))

    # extract each column once instead of once per pair
    columns = [df.iloc[:, k] for k in range(n_cols)]

    # I(X,Y) == I(Y,X), so only the upper triangle (diagonal included)
    # is computed and then mirrored; this halves the number of
    # histogram evaluations and makes the matrix exactly symmetric
    for i in range(n_cols):
        for j in range(i, n_cols):
            mi = mutual_information(n_bins, columns[i], columns[j])
            mi_matrix[i, j] = mi
            mi_matrix[j, i] = mi

    return mi_matrix

    
def compute_normed_mi_matrix(mi_matrix):
    """ Normalize a Mutual Information matrix by its diagonal.

    Each element [i, j] is divided by sqrt(M[i, i] * M[j, j]), so the
    diagonal of the result is 1 wherever the original diagonal is
    positive.

        Return: normed_mi_matrix
    """
    diagonal = np.diag(mi_matrix)

    # outer product gives M[i, i] * M[j, j] for every pair (i, j)
    pairwise_scale = np.sqrt(np.outer(diagonal, diagonal))

    return mi_matrix / pairwise_scale

In [None]:

from scipy.cluster.hierarchy import dendrogram, linkage
def plot_dendogram(matrix, channel_labels):
    """ Cluster the rows of ``matrix`` hierarchically and plot the tree.

    The linkage is built with Ward's method on Euclidean distances and
    drawn on a new 10x7 figure with ``channel_labels`` as leaf labels.

    Returns: the dict produced by ``dendrogram`` and the linkage array
    """
    _, axes = plt.subplots(1, 1, figsize=(10, 7))

    ward_links = linkage(matrix, metric="euclidean", method="ward")
    tree = dendrogram(ward_links, labels=channel_labels, ax=axes)

    axes.set_title('Dendrogram')
    axes.set_xlabel('EEG channels')
    axes.set_ylabel('Euclidean distances')

    return tree, ward_links

In [None]:
# short labels of the EEG channels used throughout the analysis
channels_to_use = [
    'FP1', 'FP2',
    'F7', 'F3', 'FZ', 'F4', 'F8',
    'T3', 'C3', 'CZ', 'C4', 'T4',
    'T5', 'P3', 'PZ', 'P4', 'T6',
    'O1', 'O2',
]

# (x, y) coordinates of each electrode, used as graph node positions;
# layout based on the international 10-20 system
node_positions = {
    'FP1': (-0.4, 0.45),
    'FP2': (0.4, 0.45),
    'F7': (-1.25, 0.33),
    'F3': (-0.65, 0.255),
    'FZ': (0.0, 0.225),
    'F4': (0.65, 0.255),
    'F8': (1.25, 0.33),
    'T3': (-1.5, 0.0),
    'C3': (-0.75, 0.0),
    'CZ': (0.0, 0.0),
    'C4': (0.75, 0.0),
    'T4': (1.5, 0.0),
    'T5': (-1.25, -0.33),
    'P3': (-0.65, -0.255),
    'PZ': (0.0, -0.225),
    'P4': (0.65, -0.255),
    'T6': (1.25, -0.33),
    'O1': (-0.4, -0.45),
    'O2': (0.4, -0.45),
}

# nodes are drawn in the insertion order of the position table
nodelist = node_positions.keys()

# styling shared by every graph plot
node_colors = 'red'
edge_colors = plt.cm.Blues

def plot_graph(G, ax=None, add_edge_labels=False):
    """ Draw graph ``G`` on the fixed electrode layout.

    Nodes and their labels are drawn for every entry of the module-level
    ``nodelist`` at the coordinates in ``node_positions``; edges are
    coloured by their 'weight' attribute through the ``edge_colors``
    colormap.

    Args:
        G: networkx graph; every edge must carry a 'weight' attribute
        ax: matplotlib axes to draw on. When omitted, a new 7x7 figure
            is created and used.
        add_edge_labels (bool): when True, the weights are also written
            at the middle of each edge

    Returns:
        dict mapping each edge to its 'weight' attribute
    """
    # Only open a figure when the caller did not supply axes; creating
    # one unconditionally left a stray empty figure next to the real plot.
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 7))

    # plot nodes
    nx.draw_networkx_nodes(G,
                           node_positions,
                           nodelist=nodelist,
                           node_size=800,
                           node_color=node_colors,
                           alpha=0.7,
                           ax=ax)

    # plot node labels (each node is labelled with its own name)
    nx.draw_networkx_labels(G,
                            node_positions,
                            labels=dict(zip(nodelist, nodelist)),
                            font_color='white',
                            font_size=14,
                            font_weight='bold',
                            ax=ax)

    edge_weights = nx.get_edge_attributes(G, 'weight')

    # plot edges, coloured by weight
    nx.draw_networkx_edges(
        G, node_positions,
        edge_cmap=edge_colors,
        width=5,
        edge_color=[G[u][v]['weight'] for u, v in G.edges],
        ax=ax
    )

    # plot edge labels (weights)
    if add_edge_labels:
        nx.draw_networkx_edge_labels(
            G, node_positions,
            edge_labels=edge_weights,
            label_pos=0.5,
            ax=ax
        )

    return edge_weights