In [2]:
import re
import os
import pandas as pd
import torch
import matplotlib.pyplot as plt
import math
import pyarrow
import fastparquet
import numpy as np



In [3]:
import spkit as sp
from spkit.data import load_data
sp.__version__

install cvxpy for L1 norm minimization for PeriodStrength fun (Ramanujan methods)


'0.0.9.4'

In [64]:

# EEG channel labels; the `eeg_<label>.csv` names in this cell's recorded output use the same five labels.
ch_names = ['AF3', 'AF4', 'PZ', 'T7', 'T8']



['eeg_AF3.csv', 'eeg_AF4.csv', 'eeg_PZ.csv', 'eeg_T7.csv', 'eeg_T8.csv']

In [106]:
def readEEgs(root_dir):
    """Load every entry found directly under ``root_dir`` as a tensor.

    Entries are visited in sorted name order so that the returned list is
    reproducible: ``os.listdir`` makes no promise about ordering.

    Parameters
    ----------
    root_dir : str or path-like
        Directory whose entries are each read with ``pandas.read_parquet``.

    Returns
    -------
    list of torch.Tensor
        One tensor per directory entry, built from the frame's ``.values``.
        Empty when the directory has no entries.
    """
    eegs = []
    for entry in sorted(os.listdir(root_dir)):
        frame = pd.read_parquet(os.path.join(root_dir, entry))
        eegs.append(torch.from_numpy(frame.values))
    return eegs


def concatenateChannels(eeg_channels):
    """Stack equally shaped 2-D channel tensors along a new last axis.

    The output is allocated with ``torch.zeros`` (default dtype) using the
    first two dimensions of the first channel, and channel ``k`` is copied
    into ``result[:, :, k]``.
    """
    first_shape = eeg_channels[0].shape
    n_channels = len(eeg_channels)

    stacked = torch.zeros((first_shape[0], first_shape[1], n_channels))

    for channel_idx in range(n_channels):
        stacked[..., channel_idx] = eeg_channels[channel_idx]

    return stacked


#plots
def plotMultiChannel(eeg, sampling_rate=None, offset=200):
    """Plot every channel of a multi-channel recording, vertically offset.

    Parameters
    ----------
    eeg : 2-D array-like
        Samples along axis 0, channels along axis 1.
    sampling_rate : number, optional
        Samples per second used to build the time axis. When omitted, the
        module-level name ``fs`` is used, so existing calls keep working.
        NOTE(review): ``fs`` is not defined in the visible part of the notebook.
    offset : number, optional
        Vertical distance between neighbouring channels (default 200).
    """
    rate = fs if sampling_rate is None else sampling_rate
    n_samples, n_channels = eeg.shape[0], eeg.shape[1]

    t = np.arange(n_samples) / rate
    # Centre the stack around zero; for 5 channels this is np.arange(-2, 3) * offset.
    offsets = (np.arange(n_channels) - n_channels // 2) * offset

    plt.figure(figsize=(15, 8))
    plt.subplot(221)
    plt.plot(t, eeg + offsets)
    plt.xlim([t[0], t[-1]])
    plt.xlabel('time (sec)')

    plt.grid()
    plt.title(f' {n_channels} channel - EEG Signal (filtered) ')

In [102]:
def ATAR_Filtering(eegs):
    """Run ``sp.eeg.ATAR`` on each entry along the first axis of ``eegs``.

    Every entry is converted to a copied numpy array, filtered with the fixed
    settings below, and written into a zero tensor of the same shape as
    ``eegs``, which is returned.
    """
    cleaned = torch.zeros(eegs.shape)
    for idx, recording in enumerate(eegs):
        raw = recording.numpy().copy()
        denoised = sp.eeg.ATAR(
            raw,
            wv='db4',
            winsize=128,
            beta=0.1,
            thr_method='ipr',
            OptMode='soft',
            verbose=1,
        )
        cleaned[idx] = torch.from_numpy(denoised)
    return cleaned


Wave-Separator

In [125]:
def WaveSeparatorTorch(eeg):
    """Split a 1-D tensor into segments that run from one trough to the next.

    A trough is an interior sample ``i`` with ``eeg[i-1] >= eeg[i] < eeg[i+1]``.
    Consecutive segments share their boundary sample. Only interior samples
    are tested, so the first sample is never compared against the last one
    (negative indexing would otherwise wrap around).

    Parameters
    ----------
    eeg : torch.Tensor
        1-D signal.

    Returns
    -------
    list of torch.Tensor
        Slices of ``eeg``. The leading segment (start of signal up to the
        first trough) is dropped, and samples after the last trough are never
        emitted. Empty when fewer than two troughs exist.
    """
    interior = eeg[1:-1]
    is_trough = (eeg[:-2] >= interior) & (interior < eeg[2:])
    # +1 converts positions inside `interior` back to positions in `eeg`.
    trough_positions = (torch.nonzero(is_trough).flatten() + 1).tolist()

    waves = []
    start = 0
    for end in trough_positions:
        waves.append(eeg[start:end + 1])
        start = end

    return waves[1:]

In [126]:
def WaveSeparator(eeg):
    """Split a pandas Series into segments that run from one trough to the next.

    A trough is an interior position ``i`` whose value satisfies
    ``eeg[i-1] >= eeg[i] < eeg[i+1]``. Consecutive segments share their
    boundary sample. Only interior positions are tested, so position 0 is
    never compared against the last sample (``iloc[-1]`` would wrap around).

    Parameters
    ----------
    eeg : pandas.Series
        1-D signal.

    Returns
    -------
    list of pandas.Series
        ``iloc`` slices of ``eeg`` with the first and the last detected
        segment removed.
    """
    values = eeg.to_numpy()
    interior = values[1:-1]
    is_trough = (values[:-2] >= interior) & (interior < values[2:])
    # +1 converts positions inside `interior` back to positions in `eeg`.
    trough_positions = np.flatnonzero(is_trough) + 1

    waves = []
    start = 0
    for end in trough_positions:
        waves.append(eeg.iloc[start:end + 1])
        start = end

    return waves[1:-1]



In [127]:
def defineWaves(waves):
    """Encode each wave as ``[mean offset, centred samples...]``.

    For every wave the first output value is the wave mean minus the mean of
    all wave means, rounded to 2 decimals. The remaining values are the
    samples minus the wave mean, rounded to 2 decimals and passed through
    float32. Built without ``Series.append``, which pandas 2.0 removed.

    Parameters
    ----------
    waves : iterable of pandas.Series
        The waves to encode.

    Returns
    -------
    list of torch.Tensor
        One float64 tensor of length ``len(wave) + 1`` per input wave.
    """
    wave_means = pd.Series([wave.mean() for wave in waves], dtype='float32')
    tot_mean = wave_means.mean()

    new_waves = []
    for wave in waves:
        w_mean = wave.mean()
        header = round(w_mean - tot_mean, 2)
        # Samples pass through float32 before being widened to float64.
        centred = np.array([round(w - w_mean, 2) for w in wave], dtype='float32')
        encoded = np.concatenate(([header], centred.astype('float64')))
        new_waves.append(torch.from_numpy(encoded))

    return new_waves

In [128]:
def defineWavesTorch(waves, eeg_mean):
    """Centre each wave on ``eeg_mean`` and round it to three decimals.

    The entries of ``waves`` are replaced in place, and the same list object
    is returned.
    """
    scale = 10**3
    for position in range(len(waves)):
        centred = waves[position] - eeg_mean
        waves[position] = torch.round(centred * scale) / scale
    return waves

In [129]:
def cutAndMergeWaves(d_waves, max_len=10):
    """Zero-pad short waves into one matrix and discard the rest.

    Waves with at most ``max_len`` samples are copied into the leading
    columns of their own row. Longer waves leave their row untouched. Rows
    that contain only zeros are then removed, which covers both the rows of
    skipped long waves and waves that are themselves entirely zero. A wave
    whose non-zero samples merely sum to zero is kept.

    Parameters
    ----------
    d_waves : sequence of 1-D torch.Tensor
        The waves to pad and merge.
    max_len : int, optional
        Maximum wave length and number of output columns (default 10).

    Returns
    -------
    torch.Tensor
        Shape ``(n_kept, max_len)``, default torch dtype.
    """
    merged_waves = torch.zeros((len(d_waves), max_len))

    for i, wave in enumerate(d_waves):
        n_samples = wave.shape[0]
        if n_samples <= max_len:
            merged_waves[i, :n_samples] = wave

    # Test for any non-zero entry rather than a non-zero row sum, so that
    # waves whose samples cancel out are not thrown away.
    has_content = (merged_waves != 0).any(dim=1)
    return merged_waves[has_content]

In [141]:
def normalize_eeg(eeg):
    """Min-max scale a signal to the range [0, 1] using its global min and max.

    Parameters
    ----------
    eeg : torch.Tensor, numpy.ndarray, pandas.DataFrame or pandas.Series
        Non-tensor inputs are converted with ``torch.from_numpy``.

    Returns
    -------
    torch.Tensor
        ``(eeg - min) / (max - min)``. A constant signal (``max == min``)
        yields all zeros instead of the NaNs a 0/0 division would produce.
    """
    if isinstance(eeg, (pd.DataFrame, pd.Series)):
        eeg = torch.from_numpy(eeg.values)
    elif isinstance(eeg, np.ndarray):
        eeg = torch.from_numpy(eeg)

    min_val = torch.min(eeg)
    max_val = torch.max(eeg)
    span = max_val - min_val
    if span == 0:
        # Flat signal: divide by one so the result is zeros, not NaN.
        span = torch.ones_like(span)
    return (eeg - min_val) / span


In [131]:
def EEGPipelineSingle(eeg):
    """Turn one signal into a matrix of zero-padded waves.

    Steps: ``normalize_eeg`` -> ``WaveSeparatorTorch`` ->
    ``defineWavesTorch`` (centred on the mean of the normalized signal) ->
    ``cutAndMergeWaves``.
    """
    normalized = normalize_eeg(eeg)
    signal_mean = normalized.mean()
    segments = WaveSeparatorTorch(normalized)
    centred = defineWavesTorch(segments, signal_mean)
    return cutAndMergeWaves(centred)

In [132]:
def EEGPipelineElectrode(electrode_data):
    """Run ``EEGPipelineSingle`` on every row of a frame and stack the waves.

    Parameters
    ----------
    electrode_data : pandas.DataFrame
        Converted through ``.values``; each row is treated as one signal.

    Returns
    -------
    torch.Tensor
        All per-row wave matrices concatenated along axis 0. When the frame
        has no rows, an empty ``(0, 10)`` tensor is returned (10 is the
        default row width produced by ``cutAndMergeWaves``), because
        ``torch.cat`` rejects an empty list.
    """
    signals = torch.from_numpy(electrode_data.values)
    waves_list = [EEGPipelineSingle(eeg) for eeg in signals]

    if not waves_list:
        return torch.zeros((0, 10))

    return torch.cat(waves_list, dim=0)

In [None]:
def remove_trailing_zeros(tensor):
    """Cut a tensor just after its last non-zero entry.

    A tensor with no non-zero entries is returned unchanged.
    """
    nonzero_positions = tensor.nonzero()

    # Something is non-zero: keep everything up to and including the last hit.
    if nonzero_positions.numel() != 0:
        return tensor[:nonzero_positions[-1] + 1]

    # Nothing to anchor on, so hand the input back as is.
    return tensor

# NOTE(review): the string literal below holds a disabled draft of `reduceWaves`; it is never executed.
# As drafted, the loop only rebinds the local name `wave`, so `waves` itself would not be modified.
'''
# Cuts out from each wave the artifact of the added zeros if it is smaller than 10 timesteps
def reduceWaves(waves):
    for wave in waves:
        wave = remove_trailing_zeros(wave)
        '''

In [207]:
def scrambleAndOrderWaves(eeg_waves):
    """Shuffle the rows, then order them by last non-zero column, longest first.

    Rows are first permuted with ``mixRows``. They are then sorted by the
    index of their last non-zero entry in descending order. The sort is
    stable, so rows with the same index keep their shuffled relative order.
    Rows containing only zeros are placed last.

    Parameters
    ----------
    eeg_waves : torch.Tensor or numpy.ndarray
        2-D matrix with one zero-padded wave per row. Arrays are converted
        with ``torch.from_numpy``.

    Returns
    -------
    torch.Tensor
        A new tensor holding the reordered rows; it has zero rows when the
        input has zero rows.
    """
    if isinstance(eeg_waves, np.ndarray):
        eeg_waves = torch.from_numpy(eeg_waves)
        print("Changed from ndarray to tensor")

    shuffled = mixRows(eeg_waves)

    # Compute each row's last non-zero index once; -1 marks an all-zero row.
    last_nonzero = []
    for wave in shuffled:
        hits = torch.nonzero(wave, as_tuple=True)[0]
        last_nonzero.append(int(hits.max()) if hits.numel() else -1)

    # Python's sort is stable, so ties keep their shuffled order.
    order = sorted(range(len(last_nonzero)), key=lambda row: -last_nonzero[row])

    return shuffled[torch.tensor(order, dtype=torch.long)]
   

def mixRows(eeg):
    """Return the rows of ``eeg`` in a random order drawn by ``torch.randperm``."""
    row_order = torch.randperm(eeg.size(0))
    return eeg[row_order, :]

In [None]:
# all_eeg = np.load(r"C:\Users\bruno\OneDrive\Desktop\BrainReader RESEARCH\Datasets\Imagenet_EEG_parquet\Train\ATAR_numpy_eeg\ATAR_all.npy")
def draftEEGATARPipeline(ATAR_eeg):
    """Extract waves from every (entry, channel) signal and order them.

    ``ATAR_eeg`` is indexed as ``ATAR_eeg[i, :, c]``. Channels (axis 2) are
    walked in the outer loop and entries (axis 0) in the inner loop. Each
    signal goes through ``EEGPipelineSingle``, the resulting wave matrices
    are concatenated, and the whole set is passed to
    ``scrambleAndOrderWaves`` exactly once.

    Parameters
    ----------
    ATAR_eeg : 3-D torch.Tensor or numpy.ndarray
        The signals to process.

    Returns
    -------
    torch.Tensor
        The shuffled and length-ordered wave matrix.
    """
    shape = ATAR_eeg.shape
    waves_list = []
    for c in range(shape[2]):
        for i in range(shape[0]):
            waves_list.append(EEGPipelineSingle(ATAR_eeg[i, :, c]))

    all_waves = torch.cat(waves_list, dim=0)
    return scrambleAndOrderWaves(all_waves)