In [1]:
%matplotlib inline
import numpy as np
import itertools
from eeg_predict import do_job
from functions.utils import merge_dicts

In [2]:
# Model parameters.
# modelName: integer index -> model identifier passed to do_job as "modelName".
modelName = dict(enumerate([
    'model_svm',              # 0
    'model_custom_mlp',       # 1
    'model_custom_mlp_multi', # 2
    'model_cnn_basic',        # 3
    'model_cnn',              # 4
    'model_cnn_max',          # 5
    'model_cnn_conv1d',       # 6
    'model_cnn_lstm',         # 7
    'model_cnn_mix',          # 8
    'model_cnn_lstm_hybrid',  # 9
    'model_cnn3d',            # 10
    'model_cnn3d_new',        # 11
]))

# pixelCount: integer index -> value passed to do_job as "pixelCount".
# NOTE(review): presumably the input size expected by each model family
# (a flat count for the MLP, [rows, cols] for the CNNs) — confirm in eeg_predict.
pixelCount = dict(enumerate([
    20,        # 0
    [4, 5],    # 1
    [8, 8],    # 2
    [16, 16],  # 3
]))

In [3]:
# Subject IDs to run; each entry is passed as the "subject" parameter of do_job.
subjectList = np.asarray((1, 2, 3, 5, 7, 9, 10, 13, 14, 16, 17, 18,
                          20, 21, 23))

In [4]:
# Train one hand-picked configuration (set the flag to True to run it).
trainWithSingleParams = False
if trainWithSingleParams:
    singleParams = dict(subject=2, modelName=modelName[10], pixelCount=pixelCount[2],
                        preictalLen=30, excludedLen=240, onlineWeights=False, l2=0.0,
                        earlyStopping=True, adaptiveL2=False, evalExcluded=True)
    do_job(job="predict", predictParams=singleParams)

In [5]:
# Plot training results for one selected run (set the flag to True to run it).
plotResults = False
if plotResults:
    runSelection = {"modelName": modelName[1], "pixelCount": pixelCount[0], "subject": 2,
                    "excludedLen": 240, "evalExcluded": False}
    do_job(job="plot",
           plotTitle=["subject", "modelName", "preictalLen"],
           plotSet=[runSelection])

In [6]:
# Train every combination of timing, training, subject and model parameters
# (set the flag to True to run the whole grid).
trainWithVariousParams = False
if trainWithVariousParams:
    # Keep the module-level `subjectList` (an array of subject IDs) intact:
    # rebinding it to a list of dicts made this cell non-idempotent (a re-run
    # would produce {"subject": {"subject": 1}}) and broke later cells that
    # iterate over `subjectList` expecting plain IDs.
    subjectParams = [{"subject": subject} for subject in subjectList]

    modelList = [{"modelName": modelName[1], "pixelCount": pixelCount[0]},
                 {"modelName": modelName[4], "pixelCount": pixelCount[2]},
                 {"modelName": modelName[10], "pixelCount": pixelCount[2]}]

    timingList = [{"preictalLen": 30, "excludedLen": 60, "onlineWeights": False},
                  {"preictalLen": 30, "excludedLen": 120, "onlineWeights": False},
                  {"preictalLen": 30, "excludedLen": 240, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 120, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 120, "onlineWeights": True},
                  {"preictalLen": 60, "excludedLen": 240, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 240, "onlineWeights": True}]

    trainList = [{"l2": 0, "earlyStopping": True, "adaptiveL2": False},
                 {"l2": 0.01, "earlyStopping": True, "adaptiveL2": False}]

    # Each combination is a tuple of partial parameter dicts; merge_dicts is
    # presumably combining them into one predictParams dict — TODO confirm.
    for trialParam in itertools.product(timingList, trainList, subjectParams, modelList):
        do_job(job="predict", predictParams=merge_dicts(trialParam))

In [7]:
# Compare training results per subject: the three main models at a fixed
# timing (preictalLen=60, excludedLen=240, no online weights, no adaptive L2).
compareResults = False
if compareResults:
    comparedModels = [(modelName[1], pixelCount[0]),
                      (modelName[4], pixelCount[2]),
                      (modelName[10], pixelCount[2])]
    for subject in subjectList:
        compareSet = [{"modelName": model, "pixelCount": pixels, "subject": subject, "excludedLen": 240,
                       "adaptiveL2": False, "onlineWeights": False, "preictalLen": 60}
                      for model, pixels in comparedModels]
        do_job(job="compare", plotLabels=["modelName", "onlineWeights", "l2"], plotTitle=["subject"],
               compareSet=compareSet)

In [8]:
# Compare results over the same parameter grid used for training, selecting
# which combinations to tabulate via selectSet.
compareWithVariousParams = True
if compareWithVariousParams:
    # Use a separate name instead of rebinding `subjectList`, so that
    # re-running this cell does not wrap already-wrapped entries
    # (e.g. {"subject": {"subject": 1}}).
    subjectParams = [{"subject": subject} for subject in subjectList]

    modelList = [{"modelName": modelName[1], "pixelCount": pixelCount[0]},
                 {"modelName": modelName[4], "pixelCount": pixelCount[2]},
                 {"modelName": modelName[10], "pixelCount": pixelCount[2]}]

    timingList = [{"preictalLen": 30, "excludedLen": 60, "onlineWeights": False},
                  {"preictalLen": 30, "excludedLen": 120, "onlineWeights": False},
                  {"preictalLen": 30, "excludedLen": 240, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 120, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 120, "onlineWeights": True},
                  {"preictalLen": 60, "excludedLen": 240, "onlineWeights": False},
                  {"preictalLen": 60, "excludedLen": 240, "onlineWeights": True}]

    trainList = [{"l2": 0, "earlyStopping": True, "adaptiveL2": False},
                 {"l2": 0.01, "earlyStopping": True, "adaptiveL2": False}]

    compareParams = [timingList, trainList, subjectParams, modelList]
    # NOTE(review): the saved output of this cell ends with
    # "AttributeError: 'CompareClass' object has no attribute
    # 'preictalLen,excludedLen,onlineWeights'" — confirm the format
    # compare_various expects for a combined (multi-key) validParams entry.
    do_job(job="compare_various", compareParams=compareParams,
           validParams=["subject", "modelName", "preictalLen,excludedLen,onlineWeights", "l2"],
           selectSet=[timingList,
                      subjectParams,
                      [{"subject": 1}]])

Selected Params:	preictalLen=30, excludedLen=60, onlineWeights=False
                       fpr                                                                           tpr                                                             
l2                    0.00                                   0.01                                   0.00                                   0.01                      
modelName model_custom_mlp model_cnn model_cnn3d model_custom_mlp model_cnn model_cnn3d model_custom_mlp model_cnn model_cnn3d model_custom_mlp model_cnn model_cnn3d
subject                                                                                                                                                              
1                 0.139988  0.034997    0.105114         0.139988  0.034997    0.105114         1.000000  1.000000    1.000000         1.000000  1.000000    1.000000
2                 0.729094  0.132563    0.232177         0.000000  0.066281    0.165841         0.666

AttributeError: 'CompareClass' object has no attribute 'preictalLen,excludedLen,onlineWeights'