# In[ ]:
from pyspark import SparkContext, SparkConf
from scipy.io import loadmat
import numpy as np
from pyspark.sql.types import Row
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, RandomForestClassificationModel
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import SQLContext
from pyspark.ml.linalg import Vectors
import os
import time
from sklearn.metrics import roc_curve, auc
from pyspark.ml import Pipeline
import gcsfs
import json
from pyspark.sql.types import *
import pandas as pd


def process_test_sample(sample, with_latency, transforms):
    """Apply ``transforms`` to the ``'data'`` entry of one named test sample.

    Args:
        sample: A ``(name, sample_X)`` pair where ``sample_X`` is indexable
            by the key ``'data'``.
        with_latency: Accepted for call-site compatibility; not used here.
        transforms: Callable applied to ``sample_X['data']``.

    Returns:
        A ``(name, transformed_data)`` tuple.
    """
    clip_name, raw_sample = sample
    return clip_name, transforms(raw_sample['data'])




def predict_subjects(gs_dir, subjects, sc, num_nodes):
    """Score the test clips of each subject with that subject's model.

    For every subject, the test clips are loaded via ``dataloader``, transformed
    on Spark with ``sample_transform``, turned into a DataFrame with ``clip`` and
    ``features`` columns, and passed to the model returned by ``load_model``.
    When ``load_model`` returns a falsy value, a model is trained with
    ``train_model`` instead.

    ``dataloader``, ``sample_transform``, ``load_model`` and ``train_model`` are
    not defined in this cell; presumably they come from the helper scripts that
    the main cell ``exec``s from GCS — confirm.

    Args:
        gs_dir: GCS directory containing ``SETTINGS.json``.
        subjects: Iterable of subject folder names (e.g. ``'Patient_8'``).
        sc: Active SparkContext.
        num_nodes: Number of worker nodes; the test RDD gets
            ``10 * num_nodes`` partitions.

    Returns:
        A list with one prediction DataFrame per subject, in ``subjects`` order.
    """
    settings = json.loads(''.join(sc.textFile(gs_dir + '/SETTINGS.json').collect()))
    proj_name = settings['gcp-project-name']
    proj_dir = settings['gcp-bucket-project-dir']
    dataset_dir = settings['dataset-dir']

    fs = gcsfs.GCSFileSystem(project=proj_name)

    # The partition count depends only on the cluster size, not on the subject.
    partitionNum = num_nodes * 10

    def rddToDf(x):
        '''Turn a (clip name, feature array) pair into Row() keyword args.'''
        name, sample_X = x
        return {'clip': name, 'features': Vectors.dense(sample_X)}

    results = []
    for subject in subjects:
        # Load this subject's raw test clips on the driver.
        start_time = time.time()
        loader = dataloader('/'.join([proj_dir, dataset_dir, subject]), fs)
        test_raw, test_names = loader.load_test_data()
        test_raw_names = list(zip(test_names, test_raw))
        print('--- ' + subject + ": Test Data Loading %s seconds ---" % (time.time() - start_time))

        # Transform the clips in parallel and build a (clip, features) DataFrame.
        start_time = time.time()
        test_rdd = sc.parallelize(test_raw_names, partitionNum)
        transformed_test_rdd = test_rdd.map(
            lambda x: process_test_sample(x, True, sample_transform)).cache()
        test_df = transformed_test_rdd.map(lambda x: Row(**rddToDf(x))).toDF()
        test_df.cache()
        print('--- ' + subject + ": Test Data Transformation %s seconds ---" % (time.time() - start_time))

        # Reuse a saved model if one is available, otherwise train one for this subject.
        model = load_model(gs_dir, subject)
        if not model:
            model = train_model(gs_dir, [subject], sc, fs, num_nodes)[0]

        # NOTE: Spark ML transform() is normally lazy, so this timing likely covers
        # only query-plan construction; scoring runs when the result is materialized.
        start_time = time.time()
        result = model.transform(test_df)
        print('--- ' + subject + ": Making Predictions %s seconds ---" % (time.time() - start_time))
        # NOTE(review): nothing is saved in this function; the message is kept
        # unchanged for log compatibility — confirm whether it is still wanted.
        print('--- ' + subject + ": Saving Trained Model ---")
        results.append(result)

    return results
        
    
def generate_pred_result(results, gs_dir, sc):
    """Combine per-subject prediction DataFrames into one submission DataFrame.

    Each input DataFrame must have ``clip`` and ``probability`` columns. Each is
    projected to rows with ``clip``, ``seizure`` and ``early`` fields, and the
    results are unioned in input order.

    Args:
        results: Non-empty list of prediction DataFrames, e.g. as returned by
            ``predict_subjects``.
        gs_dir: Unused; kept for backward compatibility.
        sc: Unused; kept for backward compatibility.

    Returns:
        The unioned submission DataFrame.

    Raises:
        ValueError: If ``results`` is empty.
    """
    def resultRddToDf(x_and_name):
        '''Convert a (clip, probability) row into Row() keyword args.'''
        # NOTE(review): assumes a probability vector with at least 3 classes
        # where classes 1 and 2 both count as seizure and class 2 is "early"
        # — confirm against the training labels.
        name, x = x_and_name
        return {
            'clip': name,
            'seizure': float(x[1] + x[2]),
            'early': float(x[2]),
        }

    def toSubmissionDf(result_df):
        '''Project one prediction DataFrame onto the submission columns.'''
        return (result_df.select(['clip', 'probability']).rdd
                .map(lambda x: Row(**resultRddToDf(x)))
                .toDF())

    if not results:
        raise ValueError('generate_pred_result() needs at least one prediction DataFrame')

    submission_df = toSubmissionDf(results[0])
    for result_df in results[1:]:
        # union() matches columns by position; every part is built the same
        # way by toSubmissionDf, so the column order agrees.
        submission_df = submission_df.union(toSubmissionDf(result_df))
    return submission_df
    

# Main script: predict the configured subjects and, when every subject of the
# full dataset was processed, write a submission CSV to GCS.
num_nodes = 2
subjects = ['Patient_8']
gs_dir = "gs://seizure_detection_data/notebooks/seizure_detection_spark_gcp"
# Number of subject folders in the full dataset; a submission file is only
# generated when predictions exist for all of them.
FULL_DATASET_SUBJECT_COUNT = 12

appName = 'seizure_detection'
# NOTE(review): with a local[*] master everything runs in one JVM, so
# spark.executor.instances is likely ignored — confirm the intended master.
conf = (SparkConf().setAppName(appName)
        .setMaster('local[*]')
        .set("spark.executor.instances", str(2 * num_nodes))
        .set('spark.executor.memory', '15G')
        .set('spark.driver.memory', '15G')
        .set('spark.driver.maxResultSize', '15G'))

# Stop a SparkContext left over from a previous run of this cell. On the first
# run `sc` does not exist yet. Stopping is best-effort.
try:
    sc.stop()
except NameError:
    pass
except Exception as err:
    print('--- Could not stop previous SparkContext: %s ---' % err)
sc = SparkContext(conf=conf)

# Presumably kept because constructing an SQLContext (and its SparkSession) is
# what makes RDD.toDF() available, which predict_subjects relies on.
sqlContext = SQLContext(sc)

json_str_rdd = sc.textFile(gs_dir + '/SETTINGS.json')
json_str = ''.join(json_str_rdd.collect())
settings = json.loads(json_str)

proj_name = settings['gcp-project-name']
proj_dir = settings['gcp-bucket-project-dir']
result_dir = settings["submission-dir"]

fs = gcsfs.GCSFileSystem(project=proj_name)

# Load the helper scripts into this module's globals. They presumably define
# dataloader, sample_transform, load_model and train_model — confirm.
# SECURITY: exec() runs whatever code is stored in the bucket; only point this
# at a bucket whose contents you trust.
for helper_script in ('spark_data_io.py', 'spark_transform.py', 'spark_processing.py'):
    with fs.open(proj_dir + '/' + helper_script) as fopen:
        exec(fopen.read())

results = predict_subjects(gs_dir, subjects, sc, num_nodes)
if len(results) == FULL_DATASET_SUBJECT_COUNT:
    # If using the full dataset from all 12 subject folders, generate a submission file.
    # Build the pandas frame before opening the output, so a failed Spark job
    # does not leave an empty submissions.csv behind.
    submission_df = generate_pred_result(results, gs_dir, sc).toPandas()
    submission_df['clip'] = submission_df['clip'] + '.mat'
    with fs.open('/'.join([proj_dir, result_dir, 'submissions.csv']), 'w') as f:
        submission_df.to_csv(f, index=False)

