## Tracking File

In [3]:
from pathlib import Path
import os
import re
from pymongo import MongoClient
from boto3.session import Session

### Reading from S3

In [19]:
ACCESS_KEY="ACCESS_KEY_HERE"
SECRET_KEY="SECRET_KEY_HERE"

session = Session(aws_access_key_id=ACCESS_KEY,
                aws_secret_access_key=SECRET_KEY)

In [20]:
s3 = session.resource('s3')
your_bucket = s3.Bucket('msds694-usfcaeeg')

In [42]:
file_list = [Path(s3_file.key) for s3_file in your_bucket.objects.all()]
edf_list = [file for file in file_list if str(os.path.basename(file)).endswith(".edf")]
file_details = [[file.parts[-2], file.parts[-1]] for file in edf_list if
                re.search(r'\d{12}_[A-Z]\d+-\d-\d.+[\.edf]$',str(file))]

In [53]:
def get_attributes(file):
    return (file[0], re.search(r'\d{12}',file[1]).group(),
            re.search(r'[A-Z]\d+-\d-\d',file[1]).group())

file_attributes = [get_attributes(file) for file in file_details]

In [54]:
# List of all files, by participant group, start time, participant id, channel
file_attributes[:5]

[('06_month_EEG', '201701061412', 'A4-1-1'),
 ('06_month_EEG', '201701131043', 'A6-1-1'),
 ('06_month_EEG', '201701131048', 'A6-1-1'),
 ('06_month_EEG', '201701131059', 'A6-1-1'),
 ('06_month_EEG', '201701271322', 'A1-1-1')]

### Reading from Mongo

In [4]:
mongos_ip = '35.164.153.142'
mongos_port = 27017
client = MongoClient(f'mongodb://{mongos_ip}:{mongos_port}/')

In [5]:
def get_collection(db_name,clxn, client):
    db = client[db_name]
    collection = db[clxn]
    return collection

In [274]:
collection = get_collection('test', 'tracking_participant')
master_list = list(collection.find())
master_list[-5:]

[{'_id': ObjectId('5c44186510fd5e358bba5b4c'),
  'file_attributes': ['24_month_EEG', '201804201331', 'B14-1-1'],
  'status': 0},
 {'_id': ObjectId('5c44186510fd5e358bba5b54'),
  'file_attributes': ['24_month_EEG', '201804201340', 'B14-1-1'],
  'status': 0},
 {'_id': ObjectId('5c44186510fd5e358bba5b5c'),
  'file_attributes': ['24_month_EEG', '201804220911', 'B9-1-2'],
  'status': 0},
 {'_id': ObjectId('5c44186510fd5e358bba5b64'),
  'file_attributes': ['24_month_EEG', '201804220912', 'B9-1-2'],
  'status': 0},
 {'_id': ObjectId('5c44186510fd5e358bba5b6c'),
  'file_attributes': ['24_month_EEG', '201804220957', 'B9-2-2'],
  'status': 0}]

In [275]:
files_present = [f["file_attributes"] for f in master_list]
files_present[-5:]

[['24_month_EEG', '201804201331', 'B14-1-1'],
 ['24_month_EEG', '201804201340', 'B14-1-1'],
 ['24_month_EEG', '201804220911', 'B9-1-2'],
 ['24_month_EEG', '201804220912', 'B9-1-2'],
 ['24_month_EEG', '201804220957', 'B9-2-2']]

In [276]:
for f in file_attributes:
    if f not in files_present:
        master_list.append({"file_attributes":f,"status":0})

In [277]:
master_list[-5:]

[{'file_attributes': ('24_month_EEG', '201804201331', 'B14-1-1'), 'status': 0},
 {'file_attributes': ('24_month_EEG', '201804201340', 'B14-1-1'), 'status': 0},
 {'file_attributes': ('24_month_EEG', '201804220911', 'B9-1-2'), 'status': 0},
 {'file_attributes': ('24_month_EEG', '201804220912', 'B9-1-2'), 'status': 0},
 {'file_attributes': ('24_month_EEG', '201804220957', 'B9-2-2'), 'status': 0}]

### Writing to Mongo

In [278]:
for file in master_list:
    collection.update_one(filter = {'file_attributes':file['file_attributes']},
                      update = {'$set':{"status":file['status']}},
                      upsert=True)

### Checking for Unprocessed Files

In [279]:
master_list = list(collection.find())

unprocessed_files = [f["file_attributes"] for f in master_list if f["status"] == 0]

In [282]:
unprocessed_files[:5]

[['06_month_EEG', '201701061412', 'A4-1-1'],
 ['06_month_EEG', '201701131043', 'A6-1-1'],
 ['06_month_EEG', '201701131048', 'A6-1-1'],
 ['06_month_EEG', '201701131059', 'A6-1-1'],
 ['06_month_EEG', '201701271322', 'A1-1-1']]

## Feature Extraction

### Function definitions

In [70]:
import os
import re
import sys
import pywt
import glob
import json
import nolds
import boto3
import pickle
import pyedflib

import numpy as np
from math import log2

from pyrqa.settings import Settings
from pyrqa.neighbourhood import FixedRadius
from pyrqa.computation import RQAComputation
from pyrqa.time_series import SingleTimeSeries

from pyspark.sql import SparkSession, SQLContext
from pyspark import SparkContext
import pyspark.sql.functions as F

from pyrqa.opencl import OpenCL
from pymongo import MongoClient

In [71]:
opencl = OpenCL(platform_id=0, device_ids=(0,))

RUN_NUMBER = 1

# Pyspark mongo config
mongos_ip = '35.164.153.142'
mongos_port = 27017
raw_clxn = 'eeg.eeg_raw'
read_pref = 'readPreference=primaryPreferred'
req_cols = ['raw', 'participant_id', 'participant_group', 'label',
            'startdate', 'sample_rate', 'signals_in_file', '_id',
            'file_duration']
client = MongoClient(f'mongodb://{mongos_ip}:{mongos_port}/')

In [72]:
# pyspark_submit_args = '--packages org.mongodb.spark:mongo-spark-connector_2.11:2.4.0 pyspark-shell'
# os.environ["PYSPARK_SUBMIT_ARGS"] = pyspark_submit_args

# Feature configs
nonrqa_features = ['power', 'sample_entropy', 'hurst_exponent', 'dfa', 'lyap0', 'lyap1', 'lyap2']
rqa_features = ['recurrence_rate', 'determinism', 'laminarity',
                'entropy_diagonal_lines', 'longest_diagonal_line',
                'average_diagonal_line', 'trapping_time']
all_features = nonrqa_features + rqa_features
embedding, tdelay, tau = 10, 2, 30
delete_cols = ['raw', 'n_raw', 't_raw']

In [74]:
def power(y):
    return np.sum(y ** 2) / y.size


def sample_entropy(y):
    # Sample Entropy
    return nolds.sampen(y)


def hurst_exponent(y):
    # Hurst exponent
    return nolds.hurst_rs(y)


def dfa(y):
    # Detrended fluctuation analysis
    return nolds.dfa(y)


# what is emb_dim ?
def lyap(y, emb_dim=10):
    # Lyapunov exponent
    return nolds.lyap_e(y, emb_dim)


function_dict = {"power": power, "sample_entropy": sample_entropy,
                 "hurst_exponent": hurst_exponent, "dfa": dfa, "lyap": lyap}

def get_rqa_features(x, f_label_i, is_fail=False):
    res = {f"{k}_{f_label_i}": np.nan for k in rqa_features}
    if not is_fail:
        for fe in rqa_features:
            res[f"{fe}_{f_label_i}"] = getattr(x, fe)
    return res

In [133]:
def trim_data(data, srate, max_nt=30):
    nt = max_nt * srate              # number of time periods
    if data.shape[0] > 60 * srate:
        m1 = 30 * srate
    else:
        m1 = 0                       # start time
    m2 = m1 + nt                     # end time
    trim_data = data[m1:m2]          # truncating data to the max number of time periods (in s)
    return trim_data


def features_settings(data, srate, wavelet='db4', mode='cpd'):

    w = pywt.Wavelet(wavelet)
    a_orig = data - np.mean(data)
    a = a_orig
    nbands = int(log2(srate)) - 1

    rec_a, rec_d = [], []                # all the approximations and details

    for i in range(nbands):
        (a, d) = pywt.dwt(a, w, mode)
        f = pow(np.sqrt(2.0), i + 1)
        rec_a.append(a / f)
        rec_d.append(d / f)

    f_labels, freqband = ['A0'], [a_orig]  # A0 is the original signal
    fs = [srate]
    f = fs[0]
    N = len(a_orig)

    for j, r in enumerate(rec_d):
        freq_name = 'D' + str(j + 1)
        f_labels.append(freq_name)
        freqband.append(r[0:N])          # wavelet details for this band
        fs.append(f)
        f = f / 2.0

    # We need one more
    f = f / 2.0
    fs.append(f)

    j = len(rec_d) - 1
    freq_name = 'A' + str(j + 1)
    f_labels.append(freq_name)
    freqband.append(rec_a[j])       # wavelet approximation for this band
    res = {}
    res['freqband'] = freqband
    res['f_labels'] = f_labels
    return res

In [167]:
def compute_non_rqa_features(freqband, f_labels, nonrqa_features=nonrqa_features):

    feature_calc = {}
    error_feet = {}

    for i, y in enumerate(freqband):
        if 'lyap' in [f[:-1] for f in nonrqa_features]:
            try:
                lyap = function_dict['lyap'](y, embedding)
                for j in range(0, 3):
                    feature_calc[f'lyap{j}' + '_' + f_labels[i]] = lyap[j]
            except Exception as e:
                for j in range(0, 3):
                    feature_calc[f'lyap{j}' + '_' + f_labels[i]] = np.nan
                error_feet = {**{str('lyap_' + f_labels[i]): e}, **error_feet}
        for feat in [f for f in nonrqa_features if not f.startswith('lyap')]:
            try:
                feature_calc[feat + "_" + f_labels[i]] = function_dict[feat](y)
            except Exception as e:
                feature_calc[feat + "_" + f_labels[i]] = np.nan
                error_feet = {**{str(feat + "_" + f_labels[i]): e}, **error_feet}

    feature_calc['error_nonrqa_feat'] = error_feet
    return feature_calc


def compute_rqa_features(freqband, f_labels):

    opencl = OpenCL(platform_id=0, device_ids=(0,))

    feature_calc = {}
    error_rqa_feat = {}

    for i, y in enumerate(freqband):

        y = SingleTimeSeries(y, embedding_dimension=embedding, time_delay=tdelay)
        settings = Settings(y, neighbourhood=FixedRadius(tau))
        computation = RQAComputation.create(settings, verbose=True, opencl=opencl)
        try:
            result = computation.run()
            result = get_rqa_features(result, f_labels[i])
        except Exception as e:
            error_rqa_feat['error_' + f_labels[i]] = e
            result = get_rqa_features(None, f_labels[i], is_fail=True)

        feature_calc = {**feature_calc, **result}

    feature_calc = {**feature_calc, **error_rqa_feat}
    return feature_calc


def fix_dtypes(x):
    for key in delete_cols:
        del x[key]
    del x['freqband']
    x['_id'] = str(x['_id'])
    x['unique_id'] = x.pop('_id')
    for k, v in x.items():
        if isinstance(v, np.floating):
            x[k] = float(x[k])
        if isinstance(v, np.integer):
            x[k] = int(x[k])
    return x

### Extracting unprocessed files

In [283]:
collection = get_collection('eeg', 'eeg_raw')

In [284]:
query = {}
for col in req_cols:
    query[col] = 1
query

{'raw': 1,
 'participant_id': 1,
 'participant_group': 1,
 'label': 1,
 'startdate': 1,
 'sample_rate': 1,
 'signals_in_file': 1,
 '_id': 1,
 'file_duration': 1}

In [286]:
unprocessed_files = [[f[0],int(f[1][:4]),int(f[1][4:6]),int(f[1][6:8]),
                      int(f[1][8:10]),int(f[1][10:]),f[2]] for f in unprocessed_files]

In [287]:
unprocessed_files[-5:]

[['24_month_EEG', 2018, 4, 20, 13, 31, 'B14-1-1'],
 ['24_month_EEG', 2018, 4, 20, 13, 40, 'B14-1-1'],
 ['24_month_EEG', 2018, 4, 22, 9, 11, 'B9-1-2'],
 ['24_month_EEG', 2018, 4, 22, 9, 12, 'B9-1-2'],
 ['24_month_EEG', 2018, 4, 22, 9, 57, 'B9-2-2']]

In [288]:
df_list = []
for file in unprocessed_files:
    df_list.extend(list(collection.find({'participant_group': file[0],'startdate_year':file[1],
                                    'startdate_month':file[2],'startdate_day':file[3],
                                    'starttime_hour':file[4],'starttime_minute':file[5],
                                    'participant_id': file[6]},query)))

### Updating channel level tracker

In [291]:
collection = get_collection('test', 'tracking')

In [292]:
#collection.drop()

In [293]:
master_channel_list = list(collection.find())
#master_channel_list = []
master_channel_list[-5:]

[{'_id': ObjectId('5c451bb210fd5e358bbd38ef'),
  'channel_attributes': ['B9-2-2',
   '24_month_EEG',
   '2018-04-22 09:57:33',
   'Fp2'],
  'n_attempts': 0,
  'status': 0},
 {'_id': ObjectId('5c451bb210fd5e358bbd38f7'),
  'channel_attributes': ['B9-2-2',
   '24_month_EEG',
   '2018-04-22 09:57:33',
   'O1'],
  'n_attempts': 0,
  'status': 0},
 {'_id': ObjectId('5c451bb210fd5e358bbd38ff'),
  'channel_attributes': ['B9-2-2',
   '24_month_EEG',
   '2018-04-22 09:57:33',
   'O2'],
  'n_attempts': 0,
  'status': 0},
 {'_id': ObjectId('5c451bb210fd5e358bbd3907'),
  'channel_attributes': ['B9-2-2',
   '24_month_EEG',
   '2018-04-22 09:57:33',
   'T7'],
  'n_attempts': 0,
  'status': 0},
 {'_id': ObjectId('5c451bb210fd5e358bbd390f'),
  'channel_attributes': ['B9-2-2',
   '24_month_EEG',
   '2018-04-22 09:57:33',
   'T8'],
  'n_attempts': 0,
  'status': 0}]

In [294]:
channels_processed = [f["channel_attributes"] for f in master_channel_list]
channels_processed[-5:]

[['B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'Fp2'],
 ['B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'O1'],
 ['B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'O2'],
 ['B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'T7'],
 ['B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'T8']]

In [295]:
def get_channel_attributes(df):
    return (df["participant_id"],df["participant_group"],str(df["startdate"]),df["label"])

channel_wise_list = [get_channel_attributes(df) for df in df_list]

In [296]:
channel_wise_list[-5:]

[('B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'Fp2'),
 ('B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'O1'),
 ('B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'O2'),
 ('B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'T7'),
 ('B9-2-2', '24_month_EEG', '2018-04-22 09:57:33', 'T8')]

In [297]:
for ch in channel_wise_list:
    if ch not in channels_processed:
        master_channel_list.append({"channel_attributes":ch,"status":0,"n_attempts":0})

In [298]:
for channel in master_channel_list:
    collection.update_one(filter = {'channel_attributes':channel['channel_attributes']},
                      update = {'$set':{"status":channel['status'],
                                        "n_attempts":channel['n_attempts']}},
                      upsert=True)

### Extracting features and stuff

In [348]:
def extract_features(select_list):
    sc = SparkSession\
            .builder\
            .appName("myEEGSession")\
            .config("spark.mongodb.input.uri",
                    f"mongodb://{mongos_ip}:{mongos_port}/{raw_clxn}?{read_pref}") \
            .config('spark.jars.packages', 'org.mongodb.spark:mongo-spark-connector_2.11:2.4.0') \
            .getOrCreate()

    #rdd = sc.sparkContext.parallelize(df_list)
    rdd = sc.sparkContext.parallelize(select_list)

    update_list = [(ch["participant_id"],ch["participant_group"],ch["startdate"],ch["label"]) for ch in select_list]

    collection = get_collection('test', 'tracking')
    for channel in update_list:
        collection.update_one(filter = {'channel_attributes':channel},
                              update = {'$inc':{"status":1,
                                                "n_attempts":1}},
                              upsert=True)

    rdd = rdd.map(lambda x: {**{'n_raw': np.array(x['raw'])}, **x})
    rdd = rdd.map(lambda x: {**{'t_raw': trim_data(x['n_raw'], x['sample_rate'])}, **x})
    rdd = rdd.map(lambda x: {**features_settings(x['t_raw'], x['sample_rate']), **x})
    # rdd = rdd.map(lambda x: {**compute_non_rqa_features(x['freqband'], x['f_labels']), **x})
    rdd = rdd.map(lambda x: {**compute_rqa_features(x['freqband'], x['f_labels']), **x})
    rdd = rdd.map(lambda x: fix_dtypes(x))
    features = rdd.collect()

    with open('test.json', 'w') as file:
        for document in features:
            file.write(json.dumps(document))
            file.write("\n")

    error_list = [{'error_type': k, 'error_msg': v, 'unique_id': d['_id'],
                   'participant_group': d['participant_group'], 'participant_id': d['participant_id'],
                   'startdate':d['startdate'],'label':d['label'],'run_num': RUN_NUMBER}
                  for d in features for k, v in d.items() if k.startswith('error') and len(v) > 0]

    print(features)
    print(error_list)

    collection = get_collection('test', 'eeg_features')
    collection.insert_many(features)

    collection = get_collection('test', 'tracking')
    for d in error_list:
        collection.update_one(filter = {'channel_attributes':(d['participant_id'], d['participant_group'],
                                                              d['startdate'],d['label'])},
                              update = {'$inc':{"status":1},'$set':{"error_list":d}}, upsert=True)

In [None]:
select_list = df_list[:8]
extract_features(select_list)

In [353]:
collection = get_collection('test', 'tracking')
failed_files = list(collection.find({"status":2}))
failed_files_to_run = [ch["channel_attributes"] for ch in failed_files if ch["n_attempts"] < 5]

collection = get_collection('eeg', 'eeg_raw')
df_failed = []

for file in failed_files_to_run:
    df_failed.extend(list(collection.find({'participant_group':file[1], 'startdate':file[2],
                                         'label':file[3], 'participant_id':file[0]},query)))

In [349]:
if (len(df_failed) > 0 and RUN_NUMBER < 5):
    collection = get_collection('test', 'tracking')
    for channel in failed_files_to_run:
        collection.update_one(filter = {'channel_attributes':channel},
                              update = {'$inc':{"n_attempts":1},
                                        '$set':{"status":0}},
                              upsert=True)
    RUN_NUMBER+=1
    select_list = df_failed
    extract_features(select_list)