In [55]:
import json
import os
import numpy as np
import pandas as pd
import tensorflow as tf


from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, space_eval
from src.data_processing.pipelines.ClassifierPipe import ClassifierPipe
from src.utilities.os_helpers import set_up_directories
from src.data_processing.processors.TrainingProcessor import TrainingProcessor
from src.models.BaseClassifier import BaseClassifier

ImportError: cannot import name 'tfa' from 'tensorflow' (/Users/mds8301/anaconda3/envs/enigma/lib/python3.11/site-packages/tensorflow/__init__.py)

In [59]:

# --- Experiment configuration -------------------------------------------
# NOTE(review): these are absolute, machine-specific locations; consider
# moving them to a configurable base directory before sharing the notebook.
MAIN_DIR = "/Users/mds8301/iterm_data_storage"
DATA_PATH = "/Users/mds8301/iterm_data_storage/raw_data_raw_data.parquet.gzip"
EXPERIMENT_NAME = "base_classifier_tuning"

# Directory that collects the artifacts written by this tuning run.
EXPERIMENT_DIR = os.path.join(MAIN_DIR, EXPERIMENT_NAME)
set_up_directories(EXPERIMENT_DIR)

# Keyword arguments forwarded verbatim to ClassifierPipe.split_data().
# Subject ids are saved next to the raw data file.
split_options = dict(
    test_size=0.3,
    test_dev_size=0.5,
    split_group="mouse_id",
    stratify_group="sex",
    target='action',
    save_subject_ids=True,
    path_to_save=os.path.dirname(DATA_PATH),
)

# Build the data pipeline: read, compute signal max/min, split, transform.
# NOTE(review): `transorm_data` is spelled as the pipeline call site spells
# it here; confirm against ClassifierPipe before renaming.
processor_pipe = (
    ClassifierPipe(DATA_PATH)
    .read_raw_data()
    .calculate_max_min_signal()
    .split_data(**split_options)
    .transorm_data(numeric_target_dict={'avoid': 1, 'escape': 0})
)


# Candidate values for every tuned hyperparameter. Each entry is turned into
# an hp.choice node whose label is identical to its dictionary key.
hyperparameter_options = {
    'number_of_layers': [2, 3],
    'number_of_units': [5, 10],
    'dropout_rate': [0.1, 0.2],
    'learning_rate': [0.0001, 0.1],
    'batch_size': [32, 64],
    'epochs': [1, 2, 3],
    'optimizers': ['adam', 'sgd'],
}

# Search space handed to hyperopt.
space = {
    name: hp.choice(name, options)
    for name, options in hyperparameter_options.items()
}

# Trial history shared by every fmin() call in this notebook.
trials = Trials()


def _f1_score(precision, recall):
    """Return the harmonic mean of ``precision`` and ``recall``.

    Returns 0.0 when both inputs are zero, so a degenerate model is scored
    as the worst possible trial instead of raising ZeroDivisionError.
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * (precision * recall) / denominator


def objective(params):
    """Train and score one hyperparameter configuration for hyperopt.

    Args:
        params: dict with the keys 'number_of_layers', 'number_of_units',
            'dropout_rate', 'learning_rate', 'batch_size', 'epochs' and
            'optimizers' (as produced by the module-level ``space``).

    Returns:
        The negated F1 score measured on ``processor_pipe``'s test split;
        negated because fmin minimises its objective.

    Side effects:
        Appends this trial's parameters and metrics to ``results.json``
        inside ``EXPERIMENT_DIR``.
    """
    # set up model
    model = BaseClassifier(params['number_of_layers'],
                           params['number_of_units'],
                           params['dropout_rate'])

    learning_rate = params['learning_rate']
    optimizer = params['optimizers']
    if optimizer == "adam":
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    elif optimizer == "sgd":
        optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
    # Any other value is passed to compile() unchanged.

    metrics = [tf.keras.metrics.BinaryAccuracy(name='accuracy'),
               tf.keras.metrics.Precision(name='precision'),
               tf.keras.metrics.Recall(name='recall'),
               tf.keras.metrics.AUC(name='auc-roc')]

    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy', metrics=metrics)

    # train model
    model.fit(processor_pipe.X_train,
              processor_pipe.y_train,
              batch_size=params['batch_size'],
              epochs=params['epochs'],
              validation_data=(processor_pipe.X_dev, processor_pipe.y_dev)
              )

    # NOTE(review): the score that drives the search is computed on the test
    # split; confirm this is intended rather than the dev split.
    # return_dict=True gives metrics keyed by name. Positional indexing was
    # the source of a bug: index 1 is accuracy and index 2 is precision, so
    # F1 used to be computed from the wrong pair of metrics.
    metric_values = model.evaluate(processor_pipe.X_test,
                                   processor_pipe.y_test,
                                   return_dict=True)
    metric_values = {name: float(value)
                     for name, value in metric_values.items()}

    all_results = {'params': params}
    all_results.update(metric_values)
    all_results['f1_score'] = _f1_score(metric_values['precision'],
                                        metric_values['recall'])
    all_results['status'] = STATUS_OK

    # NOTE(review): every trial appends another JSON object, so the file is a
    # sequence of concatenated objects rather than a single JSON document.
    with open(os.path.join(EXPERIMENT_DIR, 'results.json'), 'a+') as f:
        json.dump(all_results, f, indent=1)

    return -1 * all_results['f1_score']


def run_trials():
    """Run the hyperopt search and persist the best configuration found.

    Minimises the module-level ``objective`` over ``space`` with TPE for up
    to 10 evaluations, recording history in the shared ``trials`` object.
    The best configuration, resolved to concrete values via ``space_eval``,
    is written to ``best_params.json`` inside ``EXPERIMENT_DIR``.

    Returns:
        The dict returned by ``fmin`` (the raw best assignment, i.e. the
        input that ``space_eval`` resolves into actual parameter values).
    """
    best_trials = fmin(objective,
                       space=space,
                       algo=tpe.suggest,
                       max_evals=10,
                       trials=trials)
    best_params = space_eval(space, best_trials)

    # Overwrite rather than append: appending a second JSON object on a
    # re-run would leave a file that json.load() can no longer parse.
    with open(os.path.join(EXPERIMENT_DIR, 'best_params.json'), 'w') as f:
        json.dump(best_params, f)

    return best_trials
# Start the hyperparameter search; the value returned by run_trials() is kept
# in `best_trials` for later inspection.
best_trials = run_trials()


  0%|          | 0/10 [00:00<?, ?trial/s, best loss=?]





Epoch 1/2                                             

  1/127 [..............................] - ETA: 2:36 - loss: 1.1596 - accuracy: 0.3438 - precision: 0.4615 - recall: 0.3000 - auc-roc: 0.1958
  2/127 [..............................] - ETA: 34s - loss: 0.9225 - accuracy: 0.4531 - precision: 0.5476 - recall: 0.5897 - auc-roc: 0.3826 
  3/127 [..............................] - ETA: 26s - loss: 0.8171 - accuracy: 0.5104 - precision: 0.5775 - recall: 0.7069 - auc-roc: 0.5154
  4/127 [..............................] - ETA: 25s - loss: 0.7226 - accuracy: 0.5859 - precision: 0.6374 - recall: 0.7436 - auc-roc: 0.6064
  5/127 [>.............................] - ETA: 23s - loss: 0.6857 - accuracy: 0.6125 - precision: 0.6727 - recall: 0.7400 - auc-roc: 0.6471
  6/127 [>.............................] - ETA: 24s - loss: 0.6881 - accuracy: 0.6302 - precision: 0.6818 - recall: 0.7563 - auc-roc: 0.6591
  7/127 [>.............................] - ETA: 21s - loss: 0.6747 - accuracy: 0.6429 - precisio





Epoch 1/2                                                                        

  1/127 [..............................] - ETA: 2:42 - loss: 0.8258 - accuracy: 0.4688 - precision: 0.5556 - recall: 0.5263 - auc-roc: 0.4575
  2/127 [..............................] - ETA: 38s - loss: 0.7792 - accuracy: 0.5781 - precision: 0.6279 - recall: 0.7105 - auc-roc: 0.5071 
  3/127 [..............................] - ETA: 31s - loss: 0.8068 - accuracy: 0.5208 - precision: 0.5672 - recall: 0.6909 - auc-roc: 0.4565
  4/127 [..............................] - ETA: 26s - loss: 0.7834 - accuracy: 0.5156 - precision: 0.5682 - recall: 0.6757 - auc-roc: 0.4702
  5/127 [>.............................] - ETA: 22s - loss: 0.7893 - accuracy: 0.5063 - precision: 0.5536 - recall: 0.6813 - auc-roc: 0.4496
  6/127 [>.............................] - ETA: 24s - loss: 0.7908 - accuracy: 0.4948 - precision: 0.5379 - recall: 0.6636 - auc-roc: 0.4501
  7/127 [>.............................] - ETA: 21s - loss: 0.7749 - 





Epoch 1/2                                                                        

  1/127 [..............................] - ETA: 2:46 - loss: 0.6887 - accuracy: 0.5625 - precision: 0.6154 - recall: 0.4706 - auc-roc: 0.6314
  2/127 [..............................] - ETA: 36s - loss: 0.6717 - accuracy: 0.5625 - precision: 0.6667 - recall: 0.5263 - auc-roc: 0.6356 
  3/127 [..............................] - ETA: 32s - loss: 0.7063 - accuracy: 0.5625 - precision: 0.6304 - recall: 0.5370 - auc-roc: 0.6016
  4/127 [..............................] - ETA: 29s - loss: 0.6966 - accuracy: 0.5547 - precision: 0.5965 - recall: 0.5000 - auc-roc: 0.6165
  5/127 [>.............................] - ETA: 25s - loss: 0.6894 - accuracy: 0.5625 - precision: 0.6125 - recall: 0.5568 - auc-roc: 0.6314
  6/127 [>.............................] - ETA: 24s - loss: 0.6986 - accuracy: 0.5729 - precision: 0.6170 - recall: 0.5577 - auc-roc: 0.6239
  7/127 [>.............................] - ETA: 22s - loss: 0.7022 - 





Epoch 1/3                                                                        

 1/64 [..............................] - ETA: 1:07 - loss: 0.7911 - accuracy: 0.6562 - precision: 0.6829 - recall: 0.7568 - auc-roc: 0.6241
 2/64 [..............................] - ETA: 13s - loss: 0.7860 - accuracy: 0.6562 - precision: 0.6778 - recall: 0.8026 - auc-roc: 0.6297 
 3/64 [>.............................] - ETA: 11s - loss: 0.7782 - accuracy: 0.6458 - precision: 0.6714 - recall: 0.8103 - auc-roc: 0.6272
 4/64 [>.............................] - ETA: 10s - loss: 0.7719 - accuracy: 0.6211 - precision: 0.6576 - recall: 0.7806 - auc-roc: 0.6232
 5/64 [=>............................] - ETA: 10s - loss: 0.7477 - accuracy: 0.6313 - precision: 0.6709 - recall: 0.7929 - auc-roc: 0.6301
 6/64 [=>............................] - ETA: 9s - loss: 0.7470 - accuracy: 0.6302 - precision: 0.6691 - recall: 0.7778 - auc-roc: 0.6295 
 7/64 [==>...........................] - ETA: 10s - loss: 0.7357 - accuracy: 0.63





Epoch 1/2                                                                        

 1/64 [..............................] - ETA: 1:19 - loss: 0.6446 - accuracy: 0.7031 - precision: 0.8235 - recall: 0.6829 - auc-roc: 0.6787
 2/64 [..............................] - ETA: 14s - loss: 0.6930 - accuracy: 0.6484 - precision: 0.7183 - recall: 0.6711 - auc-roc: 0.6259 
 3/64 [>.............................] - ETA: 14s - loss: 0.6809 - accuracy: 0.6510 - precision: 0.7143 - recall: 0.6957 - auc-roc: 0.6415
 4/64 [>.............................] - ETA: 13s - loss: 0.6972 - accuracy: 0.6211 - precision: 0.6939 - recall: 0.6623 - auc-roc: 0.6218
 5/64 [=>............................] - ETA: 12s - loss: 0.7041 - accuracy: 0.6031 - precision: 0.6684 - recall: 0.6579 - auc-roc: 0.6081
 6/64 [=>............................] - ETA: 12s - loss: 0.7038 - accuracy: 0.6042 - precision: 0.6592 - recall: 0.6592 - auc-roc: 0.6123
 7/64 [==>...........................] - ETA: 11s - loss: 0.7060 - accuracy: 0.60





Epoch 1/3                                                                        

 1/64 [..............................] - ETA: 1:09 - loss: 0.7235 - accuracy: 0.3750 - precision: 0.3947 - recall: 0.4688 - auc-roc: 0.3770
 2/64 [..............................] - ETA: 20s - loss: 0.7151 - accuracy: 0.4453 - precision: 0.5065 - recall: 0.5417 - auc-roc: 0.4084 
 3/64 [>.............................] - ETA: 20s - loss: 0.7261 - accuracy: 0.3958 - precision: 0.4821 - recall: 0.4821 - auc-roc: 0.3493
 4/64 [>.............................] - ETA: 18s - loss: 0.7236 - accuracy: 0.4141 - precision: 0.4897 - recall: 0.4830 - auc-roc: 0.3661
 5/64 [=>............................] - ETA: 16s - loss: 0.7228 - accuracy: 0.4156 - precision: 0.5055 - recall: 0.4868 - auc-roc: 0.3558
 6/64 [=>............................] - ETA: 15s - loss: 0.7205 - accuracy: 0.4167 - precision: 0.5071 - recall: 0.4714 - auc-roc: 0.3614
 7/64 [==>...........................] - ETA: 13s - loss: 0.7168 - accuracy: 0.43





Epoch 1/3                                                                        

  1/127 [..............................] - ETA: 2:36 - loss: 0.6863 - accuracy: 0.5000 - precision: 0.6818 - recall: 0.6250 - auc-roc: 0.4115
  2/127 [..............................] - ETA: 29s - loss: 0.7212 - accuracy: 0.4844 - precision: 0.6190 - recall: 0.6047 - auc-roc: 0.4114 
  3/127 [..............................] - ETA: 34s - loss: 0.6996 - accuracy: 0.5312 - precision: 0.6515 - recall: 0.6615 - auc-roc: 0.4452
  4/127 [..............................] - ETA: 27s - loss: 0.7280 - accuracy: 0.5078 - precision: 0.6250 - recall: 0.6471 - auc-roc: 0.3881
  5/127 [>.............................] - ETA: 28s - loss: 0.7289 - accuracy: 0.5125 - precision: 0.6216 - recall: 0.6571 - auc-roc: 0.3913
  6/127 [>.............................] - ETA: 26s - loss: 0.7331 - accuracy: 0.5000 - precision: 0.6031 - recall: 0.6423 - auc-roc: 0.4033
  7/127 [>.............................] - ETA: 25s - loss: 0.7461 - 





Epoch 1/3                                                                        

  1/127 [..............................] - ETA: 2:40 - loss: 0.9093 - accuracy: 0.2812 - precision: 0.4444 - recall: 0.3810 - auc-roc: 0.2727
  2/127 [..............................] - ETA: 29s - loss: 0.7733 - accuracy: 0.4844 - precision: 0.5897 - recall: 0.5750 - auc-roc: 0.4802 
  3/127 [..............................] - ETA: 23s - loss: 0.8050 - accuracy: 0.4688 - precision: 0.5469 - recall: 0.6140 - auc-roc: 0.4606
  4/127 [..............................] - ETA: 22s - loss: 0.8131 - accuracy: 0.4688 - precision: 0.5057 - recall: 0.6377 - auc-roc: 0.4786
  5/127 [>.............................] - ETA: 20s - loss: 0.8426 - accuracy: 0.4688 - precision: 0.5133 - recall: 0.6591 - auc-roc: 0.4312
  6/127 [>.............................] - ETA: 21s - loss: 0.8498 - accuracy: 0.4583 - precision: 0.5188 - recall: 0.6330 - auc-roc: 0.3993
  7/127 [>.............................] - ETA: 18s - loss: 0.8497 - 





 1/64 [..............................] - ETA: 1:18 - loss: 0.9214 - accuracy: 0.4844 - precision: 0.6667 - recall: 0.1667 - auc-roc: 0.3963
 2/64 [..............................] - ETA: 16s - loss: 0.9695 - accuracy: 0.4375 - precision: 0.6667 - recall: 0.1053 - auc-roc: 0.3694 
 3/64 [>.............................] - ETA: 13s - loss: 0.9319 - accuracy: 0.4583 - precision: 0.6000 - recall: 0.1376 - auc-roc: 0.3911
 4/64 [>.............................] - ETA: 12s - loss: 0.9514 - accuracy: 0.4297 - precision: 0.5484 - recall: 0.1141 - auc-roc: 0.3973
 5/64 [=>............................] - ETA: 11s - loss: 0.9513 - accuracy: 0.4094 - precision: 0.5238 - recall: 0.1152 - auc-roc: 0.4159
 6/64 [=>............................] - ETA: 10s - loss: 0.9556 - accuracy: 0.4089 - precision: 0.4815 - recall: 0.1156 - auc-roc: 0.3896
 7/64 [==>...........................] - ETA: 10s - loss: 0.9597 - accuracy: 0.3906 - precision: 0.4348 - recall: 0.1136 - auc-roc: 0.3799
 8/64 [==>...............





Epoch 1/2                                                                        

  1/127 [..............................] - ETA: 2:06 - loss: 1.2556 - accuracy: 0.3750 - precision: 0.4706 - recall: 0.4211 - auc-roc: 0.2733
  2/127 [..............................] - ETA: 27s - loss: 1.2579 - accuracy: 0.4219 - precision: 0.5625 - recall: 0.4390 - auc-roc: 0.2869 
  3/127 [..............................] - ETA: 19s - loss: 1.1902 - accuracy: 0.4167 - precision: 0.5385 - recall: 0.4667 - auc-roc: 0.3060
  5/127 [>.............................] - ETA: 15s - loss: 1.1803 - accuracy: 0.4062 - precision: 0.5057 - recall: 0.4583 - auc-roc: 0.3015
  6/127 [>.............................] - ETA: 14s - loss: 1.2169 - accuracy: 0.3854 - precision: 0.4904 - recall: 0.4397 - auc-roc: 0.2718
  7/127 [>.............................] - ETA: 14s - loss: 1.2307 - accuracy: 0.3795 - precision: 0.4793 - recall: 0.4328 - auc-roc: 0.2612
  8/127 [>.............................] - ETA: 13s - loss: 1.2369 - 