In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

import scipy.signal as sig
from scipy.stats import skew, kurtosis
from scipy.integrate import simpson as simps 

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

In [2]:
def load_data(subject='A01', flag='T', n_classes=4, data_dir="./data_save/2a"):
    """Load one subject/session file as a NumPy array plus torch tensors.

    Reads ``{data_dir}/{subject}{flag}.npz``, which must contain the arrays
    ``data`` and ``label``.

    Args:
        subject: Subject id used in the file name, e.g. ``'A01'``.
        flag: Session suffix used in the file name, e.g. ``'T'``.
        n_classes: If 2, keep only trials whose label is the smallest label
            value or that value + 1, and relabel them 0/1. Any other value
            keeps every trial.
        data_dir: Directory holding the ``.npz`` files. The default is the
            path that was previously hard-coded.

    Returns:
        ``(data_np, data, label)``:
        - ``data_np``: the (possibly class-filtered) NumPy data.
        - ``data``: the same data as a float32 tensor.
        - ``label``: int64 labels shifted so that their minimum is 0.
    """
    # Use a context manager so the NpzFile handle is closed; np.load otherwise
    # keeps the underlying zip file open.
    with np.load(f"{data_dir}/{subject}{flag}.npz") as npz:
        data_np = npz['data']
        label_np = npz['label']

    if n_classes == 2:
        # Assumes class ids are consecutive integers; keeps the first two ids.
        first_class = np.unique(label_np)[0]
        second_class = first_class + 1
        mask = (label_np == first_class) | (label_np == second_class)
        data_np = data_np[mask]
        label_np = np.where(label_np[mask] == first_class, 0, 1)

    data = torch.from_numpy(data_np).float()
    label = torch.from_numpy(label_np).long()
    if label.numel() > 0:
        # Shift so class ids start at 0; .min() would raise on an empty tensor.
        label = label - label.min().item()

    return data_np, data, label

In [3]:
'''plv'''
def get_plv(data, label, norm=True, plot=False, save=False):
    """Compute per-trial PLV matrices, optionally normalize them and plot trial 0.

    Args:
        data: Trials passed straight to ``compute_plv`` (defined elsewhere in
            the notebook).
        label: Per-trial labels; must have the same length as the PLV result.
        norm: If truthy, pass the PLV stack through ``z_score_norm`` (defined
            elsewhere in the notebook).
        plot: If truthy, draw a heatmap of the first trial's matrix.
        save: If truthy (and ``plot`` is set), also write the figure to
            ``plv_heatmap.png`` before showing it.

    Returns:
        The (optionally normalized) PLV stack.

    Raises:
        AssertionError: if the number of PLV matrices differs from ``len(label)``.
    """
    plv = compute_plv(data)
    assert len(plv) == len(label)
    if norm:
        plv = z_score_norm(plv)
    print("Datas shape:", plv.shape)
    print("Labels shape:", label.shape)

    if plot:
        heat_map = plv[0]
        fig, ax = plt.subplots(figsize=(8, 6))
        # Raw PLV is bounded to [0, 1], so pin the colour scale there. Z-scored
        # values are not confined to that range; autoscale instead of clipping.
        vmin, vmax = (None, None) if norm else (0, 1)
        image = ax.imshow(heat_map, cmap='viridis', vmin=vmin, vmax=vmax)
        fig.colorbar(image, ax=ax, label='PLV Value')
        ax.set_title("Phase Locking Value (PLV) Matrix")

        # Label every third electrode to keep the axes readable.
        x_ticks = range(0, heat_map.shape[1], 3)
        y_ticks = range(0, heat_map.shape[0], 3)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)
        ax.set_xticklabels(x_ticks, rotation=45)
        ax.set_yticklabels(y_ticks)

        ax.set_xlabel("Electrode Index")
        ax.set_ylabel("Electrode Index")

        if save:
            fig.savefig('plv_heatmap.png')

        plt.show()

    return plv

In [4]:
'''utils'''
def aggregate_eeg_data(data_np,band):
    """Replicate a 3-D array once per frequency band along a new trailing axis.

    Args:
        data_np: 3-D array; any other rank trips the assertion.
        band: Sequence of band edges; ``len(band) - 1`` copies are produced.

    Returns:
        A fresh array of shape ``(*data_np.shape, len(band) - 1)``. The dtype
        is NumPy-promoted against float64 (e.g. float32/int -> float64), so
        downstream code can filter it in place without touching the caller's
        data.
    """
    assert data_np.ndim == 3
    expanded = np.copy(data_np[..., np.newaxis])
    n_bands = len(band) - 1
    # Broadcasting against a float ones-vector tiles the trailing axis.
    return expanded * np.ones(n_bands)

def bandpass(data: np.ndarray, edges: list[float], sample_rate: float, poles: int = 5):
    """Zero-phase Butterworth band-pass of ``data`` along its last axis.

    Args:
        data: Signal(s) to filter; filtering runs along the last axis.
        edges: ``[low, high]`` cut-off frequencies in Hz.
        sample_rate: Sampling rate in Hz.
        poles: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The forward-backward filtered signal, same shape as ``data``.
    """
    return sig.sosfiltfilt(
        sig.butter(N=poles, Wn=edges, btype='bandpass', fs=sample_rate, output='sos'),
        data,
    )

def batch_bandpass(data_np: np.ndarray, band: list[float], fs: float, poles=5):
    """Band-pass each slice of the trailing band axis with its own filter, in place.

    Slice ``data_np[:, :, :, i]`` is filtered along axis 2 with a zero-phase
    Butterworth band-pass whose edges are ``band[i]`` and ``band[i + 1]``.

    Args:
        data_np: Array indexed ``[:, :, time, band_slice]``. It is overwritten
            with the filtered values, so it should be a float array the caller
            does not need afterwards.
        band: Band edges; must contain exactly ``data_np.shape[3] + 1`` values.
        fs: Sampling rate in Hz.
        poles: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        ``data_np`` itself, modified in place.

    Raises:
        AssertionError: if ``len(band)`` does not match the number of slices.
    """
    n_bands = data_np.shape[3]
    assert len(band) == n_bands + 1, (
        f"band must have {n_bands + 1} edges for {n_bands} band slices, got {len(band)}"
    )

    # Design every filter first, so an invalid band raises before any slice of
    # data_np has been overwritten.
    sos_list = [
        sig.butter(poles, [low, high], 'bandpass', fs=fs, output='sos')
        for low, high in zip(band[:-1], band[1:])
    ]

    for i, sos in enumerate(sos_list):
        data_np[:, :, :, i] = sig.sosfiltfilt(sos, data_np[:, :, :, i], axis=2)

    return data_np

def bandpower(data, low, high, fs, window_sec=2):
    """Absolute power between ``low`` and ``high`` Hz from a Welch PSD.

    The PSD is estimated with ``scipy.signal.welch`` along the last axis and
    integrated over the selected bins with Simpson's rule.

    Args:
        data: Signal samples. A 1-D signal yields a scalar; an array shaped
            ``(..., time)`` yields one power per leading index.
        low, high: Inclusive frequency limits in Hz.
        fs: Sampling rate in Hz.
        window_sec: Welch segment length in seconds. The default of 2 s
            matches the previously hard-coded ``2 * fs`` samples.

    Returns:
        The integrated PSD over ``[low, high]`` (units of ``data`` squared).
    """
    nperseg = int(window_sec * fs)
    freqs, psd = sig.welch(data, fs, nperseg=nperseg)

    in_band = (freqs >= low) & (freqs <= high)
    freq_res = freqs[1] - freqs[0]

    # Index and integrate along the frequency (last) axis so N-D input works;
    # for 1-D data this is identical to psd[in_band].
    return simps(psd[..., in_band], dx=freq_res, axis=-1)

def bandpowercalc(data_np,band,fs):
    """Band power for every (trial, channel, band) cell of a band-split array.

    Args:
        data_np: Array indexed ``[trial, channel, time, band_slice]``.
        band: Band edges; slice ``k`` is integrated over ``[band[k], band[k + 1]]``.
        fs: Sampling rate in Hz.

    Returns:
        A float64 array of shape ``(trials, channels, bands)``, where entry
        ``[j, i, k]`` is ``bandpower(data_np[j, i, :, k], band[k], band[k + 1], fs)``.
    """
    dims = data_np.shape
    powers = np.zeros([dims[0], dims[1], dims[3]])
    for trial, chan, b in np.ndindex(*powers.shape):
        powers[trial, chan, b] = bandpower(
            data_np[trial, chan, :, b], band[b], band[b + 1], fs
        )
    return powers

In [5]:
def get_nodeFeature(data_np, fs=250, band=None):
    """Per-channel band-power node features for every trial.

    Pipeline: replicate the data once per sub-band, band-pass each copy, then
    integrate its Welch PSD over that sub-band.

    Args:
        data_np: 3-D array (``aggregate_eeg_data`` asserts the rank), presumably
            ``(trials, channels, time)``; it is not modified.
        fs: Sampling rate in Hz.
        band: Ascending band edges in Hz; consecutive pairs define the
            sub-bands. Defaults to 8-40 Hz in 4 Hz steps (8 sub-bands).

    Returns:
        A float32 tensor of shape ``(data_np.shape[0], data_np.shape[1], len(band) - 1)``.
    """
    if band is None:
        band = list(range(8, 41, 4))
    # aggregate_eeg_data returns a fresh array, so the in-place filtering in
    # batch_bandpass never touches the caller's data_np.
    expanded = aggregate_eeg_data(data_np, band)
    filtered = batch_bandpass(expanded, band, fs=fs)
    powers = bandpowercalc(filtered, band, fs)
    return torch.tensor(powers, dtype=torch.float32)

def get_adj(plv, threshold=0.3):
    """Weighted adjacency: keep PLV entries strictly above ``threshold``.

    Args:
        plv: Torch tensor of PLV values. ``.float()`` is called on the
            comparison mask, so a NumPy array is not accepted.
        threshold: Strict lower cut-off.

    Returns:
        A tensor in which entries with ``plv > threshold`` keep their value and
        all others are multiplied by 0 (so a NaN input stays NaN).
    """
    keep = plv > threshold
    return keep.float() * plv

def get_input_PyG(data_np, plv, label, threshold=0.3, fs=250):
    """Build one PyTorch Geometric ``Data`` graph per trial.

    Node features are the per-channel band powers from ``get_nodeFeature``. A
    directed edge ``(i, j)`` exists wherever ``plv[trial, i, j] > threshold``,
    which is exactly the set of entries ``get_adj`` keeps. Diagonal entries
    above the threshold become self-loops. Edge weights are not attached.

    Args:
        data_np: EEG trials passed to ``get_nodeFeature``.
        plv: Torch tensor of per-trial connectivity matrices indexed
            ``[trial, node, node]``.
        label: Per-trial labels; ``label[i]`` becomes the graph's ``y``.
        threshold: Strict PLV cut-off for creating an edge.
        fs: Sampling rate in Hz. The default matches the previously
            hard-coded 250.

    Returns:
        A list of ``torch_geometric.data.Data`` objects, one per trial.

    Raises:
        AssertionError: if features, PLV and labels disagree on the trial count.
    """
    x = get_nodeFeature(data_np, fs)
    print("nodeShape:", x.shape)
    assert x.shape[0] == plv.shape[0] == len(label), (
        f"trial count mismatch: features {x.shape[0]}, plv {plv.shape[0]}, labels {len(label)}"
    )

    # Select edges directly from plv with the same strict `>` that get_adj uses.
    # The previous `get_adj(plv)[i] >= threshold` test also matched the zeros
    # get_adj writes for suppressed pairs whenever threshold <= 0, which produced
    # a near-complete graph. For threshold > 0 the edge set is unchanged.
    edge_mask = plv > threshold

    data_list = []
    for i in range(plv.shape[0]):
        source_nodes, target_nodes = torch.where(edge_mask[i])
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)
        data_list.append(
            Data(
                x=x[i],
                edge_index=edge_index,
                y=label[i],
            )
        )

    return data_list