In [1]:
from pyspark.sql import SparkSession, SQLContext
import pyspark.sql.functions as F
from pyspark.sql.functions import isnan, when, count, col
from pyspark.ml.feature import Imputer
from pyspark.sql import DataFrameReader
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler,StringIndexer
from pyspark.sql.window import Window

In [2]:
import os
pyspark_submit_args = '--packages org.mongodb.spark:mongo-spark-connector_2.11:2.4.0 pyspark-shell'
os.environ["PYSPARK_SUBMIT_ARGS"] = pyspark_submit_args

In [3]:
sc = SparkSession\
    .builder\
    .appName("myEEGSession")\
    .config("spark.mongodb.input.uri", "mongodb://54.188.74.0/test_hx") \
    .config('spark.jars.packages','org.mongodb.spark:mongo-spark-connector_2.11:2.4.0')\
    .getOrCreate()

In [4]:
ss = SparkSession.builder.getOrCreate()

In [5]:
logger = sc._jvm.org.apache.log4j
logger.LogManager.getRootLogger().setLevel(logger.Level.FATAL)

### Load the Data

In [6]:
df = sc.read.format("com.mongodb.spark.sql.DefaultSource").option("uri", "mongodb://54.188.74.0/test_hx.eeg_features").load()

In [7]:
df.count()

392

In [8]:
print(df.columns)

['_id', 'dfa_A0', 'dfa_A7', 'dfa_D1', 'dfa_D2', 'dfa_D3', 'dfa_D4', 'dfa_D5', 'dfa_D6', 'dfa_D7', 'error_nonrqa_feat', 'f_labels', 'file_duration', 'hurst_exponent_A0', 'hurst_exponent_A7', 'hurst_exponent_D1', 'hurst_exponent_D2', 'hurst_exponent_D3', 'hurst_exponent_D4', 'hurst_exponent_D5', 'hurst_exponent_D6', 'hurst_exponent_D7', 'label', 'lyap0_A0', 'lyap0_A7', 'lyap0_D1', 'lyap0_D2', 'lyap0_D3', 'lyap0_D4', 'lyap0_D5', 'lyap0_D6', 'lyap0_D7', 'lyap1_A0', 'lyap1_A7', 'lyap1_D1', 'lyap1_D2', 'lyap1_D3', 'lyap1_D4', 'lyap1_D5', 'lyap1_D6', 'lyap1_D7', 'lyap2_A0', 'lyap2_A7', 'lyap2_D1', 'lyap2_D2', 'lyap2_D3', 'lyap2_D4', 'lyap2_D5', 'lyap2_D6', 'lyap2_D7', 'participant_group', 'participant_id', 'power_A0', 'power_A7', 'power_D1', 'power_D2', 'power_D3', 'power_D4', 'power_D5', 'power_D6', 'power_D7', 'sample_entropy_A0', 'sample_entropy_A7', 'sample_entropy_D1', 'sample_entropy_D2', 'sample_entropy_D3', 'sample_entropy_D4', 'sample_entropy_D5', 'sample_entropy_D6', 'sample_entropy

## Joining Tables

In [167]:
labels = df.select('label').distinct().rdd.keys().collect()

In [168]:
master_channel_list = ["Fp1","Fp2","F7","F3","Fz","F4","F8","T7","C3","Cz",
                       "C4","T8","P7","P3","Pz","P4","P8","O1","O2"]

In [169]:
df_list = [df.filter(col("label")== l) for l in labels]

In [170]:
df.count()

248

In [171]:
columns_all = df.columns

In [172]:
columns_except = ['_id','participant_group','participant_id','f_labels',
                  'file_duration', 'label',  'sample_rate','signals_in_file','startdate','unique_id']

In [173]:
columns_rename = [c for c in columns_all if c not in columns_except]

In [174]:
labels

['O1', 'C3', 'C4', 'Fp1', 'T7', 'O2', 'T8', 'Fp2']

In [175]:
columns_rename

['dfa_A0',
 'dfa_A7',
 'dfa_D1',
 'dfa_D2',
 'dfa_D3',
 'dfa_D4',
 'dfa_D5',
 'dfa_D6',
 'dfa_D7',
 'hurst_exponent_A0',
 'hurst_exponent_A7',
 'hurst_exponent_D1',
 'hurst_exponent_D2',
 'hurst_exponent_D3',
 'hurst_exponent_D4',
 'hurst_exponent_D5',
 'hurst_exponent_D6',
 'hurst_exponent_D7',
 'lyap0_A0',
 'lyap0_A7',
 'lyap0_D1',
 'lyap0_D2',
 'lyap0_D3',
 'lyap0_D4',
 'lyap0_D5',
 'lyap0_D6',
 'lyap0_D7',
 'lyap1_A0',
 'lyap1_A7',
 'lyap1_D1',
 'lyap1_D2',
 'lyap1_D3',
 'lyap1_D4',
 'lyap1_D5',
 'lyap1_D6',
 'lyap1_D7',
 'lyap2_A0',
 'lyap2_A7',
 'lyap2_D1',
 'lyap2_D2',
 'lyap2_D3',
 'lyap2_D4',
 'lyap2_D5',
 'lyap2_D6',
 'lyap2_D7',
 'power_A0',
 'power_A7',
 'power_D1',
 'power_D2',
 'power_D3',
 'power_D4',
 'power_D5',
 'power_D6',
 'power_D7',
 'sample_entropy_A0',
 'sample_entropy_A7',
 'sample_entropy_D1',
 'sample_entropy_D2',
 'sample_entropy_D3',
 'sample_entropy_D4',
 'sample_entropy_D5',
 'sample_entropy_D6',
 'sample_entropy_D7']

In [176]:
for i in range(len(labels)):
    for c in df_list[i].columns:
        if c in columns_rename:
            df_list[i] = df_list[i].withColumnRenamed(c, c+"_"+labels[i])

In [177]:
for i in range(len(labels)):
    print(df_list[i].columns)
    print('\n')

['_id', 'dfa_A0_O1', 'dfa_A7_O1', 'dfa_D1_O1', 'dfa_D2_O1', 'dfa_D3_O1', 'dfa_D4_O1', 'dfa_D5_O1', 'dfa_D6_O1', 'dfa_D7_O1', 'f_labels', 'file_duration', 'hurst_exponent_A0_O1', 'hurst_exponent_A7_O1', 'hurst_exponent_D1_O1', 'hurst_exponent_D2_O1', 'hurst_exponent_D3_O1', 'hurst_exponent_D4_O1', 'hurst_exponent_D5_O1', 'hurst_exponent_D6_O1', 'hurst_exponent_D7_O1', 'label', 'lyap0_A0_O1', 'lyap0_A7_O1', 'lyap0_D1_O1', 'lyap0_D2_O1', 'lyap0_D3_O1', 'lyap0_D4_O1', 'lyap0_D5_O1', 'lyap0_D6_O1', 'lyap0_D7_O1', 'lyap1_A0_O1', 'lyap1_A7_O1', 'lyap1_D1_O1', 'lyap1_D2_O1', 'lyap1_D3_O1', 'lyap1_D4_O1', 'lyap1_D5_O1', 'lyap1_D6_O1', 'lyap1_D7_O1', 'lyap2_A0_O1', 'lyap2_A7_O1', 'lyap2_D1_O1', 'lyap2_D2_O1', 'lyap2_D3_O1', 'lyap2_D4_O1', 'lyap2_D5_O1', 'lyap2_D6_O1', 'lyap2_D7_O1', 'participant_group', 'participant_id', 'power_A0_O1', 'power_A7_O1', 'power_D1_O1', 'power_D2_O1', 'power_D3_O1', 'power_D4_O1', 'power_D5_O1', 'power_D6_O1', 'power_D7_O1', 'sample_entropy_A0_O1', 'sample_entropy_A7

In [178]:
temp = df.select(*['participant_group','participant_id', 'startdate','file_duration']).distinct()

In [179]:
temp.show()

+-----------------+--------------+-------------------+-------------+
|participant_group|participant_id|          startdate|file_duration|
+-----------------+--------------+-------------------+-------------+
|     12_month_EEG|        B1-2-3|2016-12-09 12:12:26|          598|
|     12_month_EEG|       B20-1-1|2016-09-09 11:09:49|          599|
|     12_month_EEG|       A14-1-1|2017-04-28 16:16:14|          598|
|     12_month_EEG|       B22-1-1|2017-06-23 14:43:19|          599|
|     06_month_EEG|        A8-1-3|2017-06-09 16:51:00|          599|
|     12_month_EEG|        B1-2-3|2016-12-09 12:23:07|          179|
|     12_month_EEG|        B8-1-1|2016-08-12 16:36:47|          598|
|     12_month_EEG|       B24-2-2|2017-03-10 14:09:35|          598|
|     12_month_EEG|       B25-1-1|2017-01-20 16:21:38|          599|
|     12_month_EEG|        B9-1-2|2017-03-10 12:59:32|          611|
|     12_month_EEG|       B24-1-2|2017-03-10 14:30:52|          598|
|     24_month_EEG|        B9-1-2|

In [180]:
for i in range(len(df_list)):
    df_list[i] = df_list[i].drop(*['_id','f_labels','label','sample_rate','signals_in_file','unique_id','file_duration'])



In [181]:
for i in range(len(df_list)):
    temp= temp.join(df_list[i],['participant_group','participant_id', 'startdate'])

In [182]:
df = temp
del temp

In [183]:
df.columns

['participant_group',
 'participant_id',
 'startdate',
 'file_duration',
 'dfa_A0_O1',
 'dfa_A7_O1',
 'dfa_D1_O1',
 'dfa_D2_O1',
 'dfa_D3_O1',
 'dfa_D4_O1',
 'dfa_D5_O1',
 'dfa_D6_O1',
 'dfa_D7_O1',
 'hurst_exponent_A0_O1',
 'hurst_exponent_A7_O1',
 'hurst_exponent_D1_O1',
 'hurst_exponent_D2_O1',
 'hurst_exponent_D3_O1',
 'hurst_exponent_D4_O1',
 'hurst_exponent_D5_O1',
 'hurst_exponent_D6_O1',
 'hurst_exponent_D7_O1',
 'lyap0_A0_O1',
 'lyap0_A7_O1',
 'lyap0_D1_O1',
 'lyap0_D2_O1',
 'lyap0_D3_O1',
 'lyap0_D4_O1',
 'lyap0_D5_O1',
 'lyap0_D6_O1',
 'lyap0_D7_O1',
 'lyap1_A0_O1',
 'lyap1_A7_O1',
 'lyap1_D1_O1',
 'lyap1_D2_O1',
 'lyap1_D3_O1',
 'lyap1_D4_O1',
 'lyap1_D5_O1',
 'lyap1_D6_O1',
 'lyap1_D7_O1',
 'lyap2_A0_O1',
 'lyap2_A7_O1',
 'lyap2_D1_O1',
 'lyap2_D2_O1',
 'lyap2_D3_O1',
 'lyap2_D4_O1',
 'lyap2_D5_O1',
 'lyap2_D6_O1',
 'lyap2_D7_O1',
 'power_A0_O1',
 'power_A7_O1',
 'power_D1_O1',
 'power_D2_O1',
 'power_D3_O1',
 'power_D4_O1',
 'power_D5_O1',
 'power_D6_O1',
 'power_D7_O1',


In [186]:
df.select(*['participant_group','participant_id', 'startdate','file_duration']).sort('participant_id').show()

+-----------------+--------------+-------------------+-------------+
|participant_group|participant_id|          startdate|file_duration|
+-----------------+--------------+-------------------+-------------+
|     06_month_EEG|       A10-1-1|2017-06-23 15:26:25|          598|
|     12_month_EEG|       A14-1-1|2017-04-28 16:16:14|          598|
|     06_month_EEG|        A4-1-1|2017-01-06 14:12:07|          598|
|     06_month_EEG|        A8-1-3|2017-06-09 16:51:00|          599|
|     06_month_EEG|        A8-2-3|2017-06-09 16:31:57|          599|
|     12_month_EEG|        B1-2-3|2016-12-09 12:12:26|          598|
|     12_month_EEG|        B1-2-3|2016-12-09 12:23:07|          179|
|     12_month_EEG|        B1-3-3|2016-12-09 11:53:32|          599|
|     12_month_EEG|       B12-1-1|2017-03-24 14:57:27|          598|
|     12_month_EEG|       B13-1-1|2016-12-23 16:50:53|          599|
|     12_month_EEG|       B17-1-1|2017-02-24 11:52:59|          598|
|     12_month_EEG|       B20-1-1|

In [193]:
window = Window.\
              partitionBy('participant_id','participant_group').\
              orderBy(df['startdate'].desc())
df = df.withColumn("rank_start_dt",F.dense_rank().over(window))

In [198]:
df = df.filter("rank_start_dt == 1")

In [200]:
df_meta = sc.read.format("com.mongodb.spark.sql.DefaultSource").option("uri", "mongodb://52.40.36.24/eeg.eeg_metadata").load()
df_meta = df_meta.drop("_id","num_recording")

In [201]:
df_meta.columns

['Delivery_type',
 'Gender',
 'Gestational_Age',
 'Maternal_age',
 'Multiple_births',
 'Prematurity_Level',
 'Relative_size',
 'Weight_gms',
 'participant_group',
 'participant_id']

In [208]:
features = df.join(df_meta, on = ['participant_group','participant_id'] , how='inner')
features = features.drop('participant_group','participant_ID','startdate','file_duration','rank_start_dt')


In [210]:
del df
del df_meta

## Feature Engineering

In [211]:
from collections import Counter
Counter([d[1] for d in features.dtypes])

Counter({'double': 509, 'string': 3})

In [212]:
feat_str = [d[0] for d in features.dtypes if d[1] == 'string']
feat_dbl = [d[0] for d in features.dtypes if d[1] in ['double','int']]

In [233]:
nulls_df = features.select([((count(when(isnan(c)| col(c).isNull(), c)))*100.0/count(c)).alias(c)  for c in features.columns])

In [215]:
nulls_df.show()

+-----------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------------+-

In [217]:
null_threshold = 5.0
def multicontd(x):
    if x < null_threshold and x != 0:
        return True

In [241]:
null_cols = nulls_df.columns
null_counts = list(nulls_df.rdd.map(tuple).collect()[0])
null_cols_imp = [null_cols[i] for i,n in enumerate(null_counts) if multicontd(n)]

In [242]:
null_cols_drop = [null_cols[i] for i,n in enumerate(null_counts) if n >= null_threshold]
features = features.drop(*null_cols_drop)

In [220]:
del nulls_df

In [243]:
null_cols_imp_db = [f for f in null_cols_imp if f in feat_dbl]
null_cols_impNew = [col+'_imp' for col in null_cols_imp_db]

In [244]:
imputer = Imputer(inputCols= null_cols_imp_db , outputCols=null_cols_impNew,strategy='median')
features = imputer.fit(features).transform(features).drop(*null_cols_imp_db)

In [245]:
null_cols_imp = [f for f in null_cols_imp if f in feat_str]
null_cols_impNew = [col+'_imp' for col in null_cols_imp]

In [284]:
features = features.drop('Delivery_type')

In [289]:
feat_str.remove('Delivery_type')
feat_str

['Gender', 'Relative_size']

In [290]:
def indexStringColumns(df, cols):
    #variable newdf will be updated several times
    newdf = df
    
    for c in cols:
        si = StringIndexer(inputCol=c, outputCol=c+"-num")
        sm = si.fit(newdf)
        newdf = sm.transform(newdf).drop(c)
        newdf = newdf.withColumnRenamed(c+"-num", c)
    return newdf

dfnumeric = indexStringColumns(features, feat_str)

In [294]:
temp_cols = dfnumeric.columns
temp_cols.remove('Prematurity_Level')
temp_cols

['lyap0_A0_O1',
 'lyap0_A7_O1',
 'lyap0_D1_O1',
 'lyap0_D2_O1',
 'lyap0_D3_O1',
 'lyap0_D4_O1',
 'lyap0_D5_O1',
 'lyap0_D6_O1',
 'lyap0_D7_O1',
 'lyap1_A0_O1',
 'lyap1_A7_O1',
 'lyap1_D1_O1',
 'lyap1_D2_O1',
 'lyap1_D3_O1',
 'lyap1_D4_O1',
 'lyap1_D5_O1',
 'lyap1_D6_O1',
 'lyap1_D7_O1',
 'lyap2_A0_O1',
 'lyap2_A7_O1',
 'lyap2_D1_O1',
 'lyap2_D2_O1',
 'lyap2_D3_O1',
 'lyap2_D4_O1',
 'lyap2_D5_O1',
 'lyap2_D6_O1',
 'lyap2_D7_O1',
 'power_A0_O1',
 'power_A7_O1',
 'power_D1_O1',
 'power_D2_O1',
 'power_D3_O1',
 'power_D4_O1',
 'power_D5_O1',
 'power_D6_O1',
 'power_D7_O1',
 'sample_entropy_A0_O1',
 'sample_entropy_A7_O1',
 'sample_entropy_D1_O1',
 'sample_entropy_D2_O1',
 'sample_entropy_D3_O1',
 'sample_entropy_D4_O1',
 'sample_entropy_D5_O1',
 'sample_entropy_D6_O1',
 'sample_entropy_D7_O1',
 'lyap0_A0_C3',
 'lyap0_A7_C3',
 'lyap0_D1_C3',
 'lyap0_D2_C3',
 'lyap0_D3_C3',
 'lyap0_D4_C3',
 'lyap0_D5_C3',
 'lyap0_D6_C3',
 'lyap0_D7_C3',
 'lyap1_A0_C3',
 'lyap1_A7_C3',
 'lyap1_D1_C3',
 'lyap1

In [296]:
features = dfnumeric.select(*temp_cols,'Prematurity_Level')

In [298]:
features = features.withColumnRenamed('Prematurity_Level', "label")

In [299]:
va = VectorAssembler(outputCol="features", inputCols=features.columns[0:-1]) #except the last col.
df = va.transform(features).select("features", "label")

### Split and  Cache the data

In [None]:
train, test = df.randomSplit([0.8,0.2])

In [305]:
train = train.cache()
test = test.cache()

### ML Models

In [314]:
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(maxDepth=30)
rfmodel = rf.fit(train)

In [315]:
rfpredicts = rfmodel.transform(test)

In [316]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(rfpredicts)
print("Accuracy = %0.4f" % (accuracy))

Accuracy = 0.4000


In [310]:
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(maxDepth=20, maxBins= 32, minInstancesPerNode=1, minInfoGain = 0)

In [311]:
dtmodel = dt.fit(train)
dtpredicts = dtmodel.transform(test)
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(dtpredicts)
print("Accuracy = %0.4f" % (accuracy))

Accuracy = 0.6000
