In [None]:
!pip install dask distributed --upgrade
!pip install "dask[dataframe]" --upgrade
!pip install mne --upgrade

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dask
  Downloading dask-2023.6.0-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
Collecting distributed
  Downloading distributed-2023.6.0-py3-none-any.whl (976 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m976.1/976.1 kB[0m [31m78.9 MB/s[0m eta [36m0:00:00[0m
Collecting importlib-metadata>=4.13.0 (from dask)
  Downloading importlib_metadata-6.7.0-py3-none-any.whl (22 kB)
Installing collected packages: importlib-metadata, dask, distributed
  Attempting uninstall: dask
    Found existing installation: dask 2022.12.1
    Uninstalling dask-2022.12.1:
      Successfully uninstalled dask-2022.12.1
  Attempting uninstall: distributed
    Found existing installation: distributed 2022.12.1
    Uninstalling distributed-2022.12.1:
      Successfully uninstalled distri

In [None]:
import os
import pandas as pd
import numpy as np
import mne
import tensorflow as tf

# NOTE: on rerun, output files that already exist are skipped, never overwritten

# path separator used when building file paths (system dependent: "\\" on Windows, "/" on Linux)
seperator = "/"

# channels of interest
csv_channel_names = [
    'E1', 'E2', 'E3', 'E4', 'E5', 'E6', 'E7', 'E8', 'E9', 'E10',
    'E11', 'E12', 'E13', 'E14', 'E15', 'E16', 'E18', 'E19', 'E20', 'E21',
    'E22', 'E23', 'E24', 'E25', 'E26', 'E27', 'E28', 'E29', 'E30', 'E31',
    'E32', 'E33', 'E34', 'E35', 'E36', 'E37', 'E38', 'E39', 'E40', 'E41',
    'E42', 'E44', 'E45', 'E46', 'E47', 'E50', 'E51', 'E52', 'E53', 'E54',
    'E55', 'E57', 'E58', 'E59', 'E60', 'E61', 'E62', 'E64', 'E65', 'E66',
    'E67', 'E69', 'E70', 'E71', 'E72', 'E74', 'E75', 'E76', 'E77', 'E78',
    'E79', 'E80', 'E82', 'E83', 'E84', 'E85', 'E86', 'E87', 'E89', 'E90',
    'E91', 'E92', 'E93', 'E95', 'E96', 'E97', 'E98', 'E100', 'E101', 'E102',
    'E103', 'E104', 'E105', 'E106', 'E108', 'E109', 'E110', 'E111', 'E112', 'E114',
    'E115', 'E116', 'E117', 'E118', 'E121', 'E122', 'E123', 'E124',
]

# frequency bands: name -> [low cut-off, high cut-off]
freqs = {
    "delta": [1, 4],
    "theta": [4, 8],
    "alpha": [8, 12],
    "beta": [12, 30],
    "gamma": [30, 70],
    "ALL": [1, 70],
}

# model name -> input shape of that model
models = {
    "CNN": [len(csv_channel_names), 350],
}

# data location (one sub-folder per disorder is expected inside it)
foldername = r""
# output location
outputPath = r""

# disorders of interest
disorders = ['HBN', 'MDD']

# channel grouping into regions of interest (ROI): ROI label -> channel indices
# NOTE(review): 24 is listed twice in "FR", and "FL" contains 108 although
# csv_channel_names has 108 entries (last index 107) — confirm what these indices refer to.
channels = {
    "TR": [35, 34, 52, 44, 43, 42, 39, 38, 37, 33, 32, 41, 45, 46, 47, 48, 49, 28, 27],
    "OR": [53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67],
    "FR": [26, 23, 22, 19, 18, 11, 25, 24, 21, 20, 17, 16, 15, 14, 10, 31, 24],
    "Center": [36, 29, 12, 99, 94, 93, 78, 72, 71, 51, 50, 30, 6, 5],
    "FL": [108, 107, 1, 0, 104, 9, 8, 7, 3, 4, 2],
    "OL": [70, 85, 84, 81, 80, 79, 76, 75, 74, 73, 69, 68],
    "TL": [106, 103, 102, 100, 97, 96, 95, 91, 98, 92, 83, 77, 90, 89, 88, 87, 86, 82],
}

# ROI labels, in processing order
roi_label = ['TR', 'OR', 'FR', 'Center', 'FL', 'OL', 'TL']


In [None]:
import threading
import dask.dataframe as dd
import csv

# to enhance tensorflow performance
# NOTE(review): both switches are presumably already the default under TensorFlow 2.x — confirm
tf.compat.v1.enable_eager_execution()
tf.compat.v1.enable_v2_behavior()

# mne based band-pass filter
def band_pass_filter(arr,l_freq=None,h_freq=None,verbose=False,sfreq=256):
    """Band-pass filter a multi-channel signal with MNE.

    Parameters
    ----------
    arr : array-like, one row per channel
        Signal to filter. Row ``i`` is labelled with ``csv_channel_names[i]``,
        so ``arr`` must not have more rows than ``csv_channel_names`` has entries.
    l_freq, h_freq : float | None
        Lower / upper cut-off passed straight to ``Raw.filter``.
    verbose : bool
        Verbosity flag forwarded to the MNE calls.
    sfreq : float
        Sampling frequency used to build the MNE ``Info`` (defaults to the
        previously hard-coded 256).

    Returns
    -------
    The filtered data of all channels, as returned by indexing the filtered
    Raw object with ``[:]``.
    """
    info = mne.create_info(ch_names=csv_channel_names[:len(arr)],ch_types="eeg",sfreq=sfreq,verbose=verbose)

    # RawArray is the public way to wrap an in-memory array; the data is copied
    # to float64 first so the caller's array is never modified by the filter.
    raw = mne.io.RawArray(np.array(arr,dtype=np.float64), info, verbose=verbose)

    # n_jobs='cuda' asks MNE for GPU filtering.
    # NOTE(review): presumably MNE falls back to the CPU when CUDA is unavailable — confirm.
    raw = raw.filter(l_freq=l_freq, h_freq=h_freq, picks=None,
                     filter_length='auto',
                     l_trans_bandwidth='auto', h_trans_bandwidth='auto', n_jobs='cuda',
                     method='fft', iir_params=None, phase='zero',
                     fir_window='hamming', fir_design='firwin',
                     pad='reflect_limited', verbose=verbose)
    filtered, _ = raw[:]
    return filtered

# resize a 2-D array to the CNN input size using TensorFlow
def resize_cnn(x, old_shape, new_shape):
    """Resize ``x`` to ``new_shape`` with ``tf.image.resize``.

    ``x`` is copied into the first three planes of a zero array of shape
    ``old_shape`` (so ``old_shape`` needs a third axis of at least 3), the
    stack is resized with the default settings of ``tf.image.resize`` and the
    first plane of the result is returned as a NumPy array.
    """
    stacked = np.zeros(old_shape)
    for plane in range(3):
        stacked[:, :, plane] = x
    resized = tf.image.resize(stacked, new_shape)
    first_plane = resized[:, :, 0]
    return np.array(first_plane)


import re
def build(foldername, filename, outputPath, disorder, fetch_length=4000):
    """Split one recording into segments, filter/resize them and write CSV files.

    The input file is read line by line; every line is parsed as one
    comma-separated row of floats (one row per channel). For each segment of
    ``fetch_length`` samples, each model in ``models`` and each band in
    ``freqs``, the segment is band-pass filtered, resized to the model input
    shape and written to::

        outputPath/disorder/model/freq_band/<filename stem><segment index>.csv

    Output files that already exist are skipped, so a rerun only fills in
    what is missing.

    Parameters
    ----------
    foldername : str
        Directory that contains ``filename``.
    filename : str
        Name of the input file; the part before the first "." names the outputs.
    outputPath : str
        Root output directory (the sub-directories must already exist).
    disorder : str
        Name of the disorder sub-directory below ``outputPath``.
    fetch_length : int
        Number of samples per segment.

    Raises
    ------
    ValueError
        If the file yields no rows, or the rows do not form a 2-D array
        (for example because they differ in length).
    """
    # reading the file as a string and splitting it in memory seems to enhance performance
    temp = dd.read_table(foldername + seperator + filename, sample=10000000, engine='c', header=None)
    arr2 = temp.to_dask_array()
    arr2 = arr2.persist()
    arr2 = arr2.compute_chunk_sizes()
    arr2 = [temp2[0].compute() for temp2 in arr2]
    if not arr2:
        raise ValueError("no rows could be read from " + foldername + seperator + filename)

    # filtering step: drop channel names that might have accidentally been
    # written in front of the first row. Names are tried from the last one
    # backwards; the first one found marks the end of the header. If none is
    # found the row is left untouched (the previous open-ended loop ran off the
    # list and raised IndexError in that case).
    for prefix in reversed(csv_channel_names):
        cutIndex = arr2[0].rfind(prefix)
        if cutIndex != -1:
            # also skip the single separator character that follows the name
            arr2[0] = arr2[0][cutIndex + len(prefix) + 1:]
            break

    # actual data
    data = np.array([np.float32(temp2.split(',')) for temp2 in arr2])
    if data.ndim == 1:
        raise ValueError("rows of " + filename + " do not form a 2-D array (rows of different length?)")

    stem = filename.split(".")[0]

    # NOTE(review): the "- 1" leaves the last complete segment unprocessed;
    # kept as in the original — confirm whether this is intentional.
    for data_counter in range((len(data[0])//fetch_length)-1):
        segment = data[:, data_counter*fetch_length:(data_counter+1)*fetch_length]
        for model, input_shape in models.items():
            for freq_band, band in freqs.items():
                out_file = outputPath+seperator+disorder+seperator+model+seperator+freq_band+seperator+stem + str(data_counter)+".csv"
                if os.path.isfile(out_file):
                    print("Skipping")
                    continue
                X = band_pass_filter(segment,l_freq=band[0],h_freq=band[1],verbose=False)
                X = resize_cnn(X, (len(X), fetch_length, 3), (input_shape[0], input_shape[1]))
                # one CSV line per row of X
                with open(out_file, 'w',newline='') as csvfile:
                    csv.writer(csvfile).writerows(X)
    print(filename)




# not used in this cell; kept in case other cells read them
retVal = []
counter = 0
all_counter = 0

# argument tuples for build(): (input folder, file name, output root, disorder)
inputs = []

# an empty output root would make the directory creation below target the filesystem root
if not outputPath:
    raise ValueError("outputPath is empty; set it to the output directory before running this cell")

for disorder in disorders:
    print(disorder)
    # make sure the output sub-directories in the output folder exist
    for model in models.keys():
        for freq_band in freqs.keys():
            os.makedirs(outputPath+seperator+disorder+seperator+model+seperator+freq_band, exist_ok=True)

    disorder_dir = foldername+seperator+disorder
    for filesname in os.listdir(disorder_dir):
        if os.path.isfile(disorder_dir + seperator + filesname):
            inputs.append((disorder_dir, filesname, outputPath, disorder))

import multiprocessing as mp

# set to False to process the files one after another in this process
# (slower, but errors are easier to trace)
run_parallel = True
if run_parallel:
    # starmap blocks until every file is processed; leaving the `with` block
    # then shuts the worker processes down
    with mp.Pool() as pool:
        pool.starmap(build, inputs)
else:
    for build_args in inputs:
        build(*build_args)

print("Done: ", disorders, list(models.keys()), list(freqs.keys()))

# necessary only for google colab, to enforce syncing google colab connection to google drive
try:
    from google.colab import drive
except ImportError:
    # not running on Google Colab: nothing to sync
    drive = None
if drive is not None:
    drive.flush_and_unmount()
    print("Synched")




HBN
NDAREH074NG8_RestingState_data.csv
NDARGL800LDW_RestingState_data.csv
NDARDX872VH6_RestingState_data.csv
NDARAA075AMK_RestingState_data.csv
NDARCY178KJP_RestingState_data.csvNDARBY518PRN_RestingState_data.csv

NDARBL117AUV_RestingState_data.csv
NDARAC904DMU_RestingState_data.csv
NDARCT472UJ7_RestingState_data.csv
NDARCJ363KLE_RestingState_data.csv
NDAREF389RY2_RestingState_data.csv
NDARCX053GU5_RestingState_data.csv
NDAREW201WD9_RestingState_data.csv
NDARBU928LV0_RestingState_data.csv
NDARBA521RA8_RestingState_data.csv
NDARAU447JZH_RestingState_data.csv
NDARGD507TDZ_RestingState_data.csv
NDARFL411AT1_RestingState_data.csv
NDARBH024NH2_RestingState_data.csv
NDARFW292PBD_RestingState_data.csv
NDARFY075REK_RestingState_data.csv
NDARCW963FP9_RestingState_data.csv
NDAREW976FNL_RestingState_data.csv
NDAREM703YFD_RestingState_data.csv
NDARDL511UND_RestingState_data.csv
NDARDR296XHN_RestingState_data.csv
NDARAM675UR8_RestingState_data.csv
NDARDJ825GBP_RestingState_data.csv
NDARCD401HGZ_Res