In [None]:
from helpers import *
from nets import *
from results_and_plots import *

In [None]:
# Base folder under the Drive mount point, and the folder passed to the
# dataset loader further down.
parent_dir = "/content/drive/My Drive"
data_dir = f"{parent_dir}/TEAA/21x14"

# Dataset formats that are evaluated, in this order:
#   21x14  - 21 channels, each with the spectrogram filtered to 8 values in the
#            frequency domain plus 6 time-domain statistics (294 dimensions in total)
#   pca136 - same as the previous one after applying PCA conserving 99% of the variance
#   pca141 - same, but using chb[1-16] for training and chb[17-24] for testing
formats = ["21x14", "pca136", "pca141"]

# One LabConfiguration per experiment: the model family, the network builder,
# whether labels are one-hot encoded before training, and the epoch count.
# RECURRENT and TDCNN entries additionally set `timesteps`.
configurations = [
    # Feed-forward networks
    LabConfiguration(FEEDFORWARD, feedforward_1, to_categorical=False, epochs=5),
    LabConfiguration(FEEDFORWARD, feedforward_2, to_categorical=True, epochs=5),
    LabConfiguration(FEEDFORWARD, feedforward_3, to_categorical=True, epochs=5),
    # Recurrent networks
    LabConfiguration(RECURRENT, rnn_1, to_categorical=True, timesteps=10, epochs=1),
    LabConfiguration(RECURRENT, rnn_2, to_categorical=True, timesteps=10, epochs=1),
    # Convolutional networks
    LabConfiguration(CONVOLUTIONAL, cnn_1, to_categorical=False, epochs=5),
    LabConfiguration(CONVOLUTIONAL, cnn_2, to_categorical=True, epochs=5),
    # TDCNN network
    LabConfiguration(TDCNN, cnn_3, to_categorical=True, timesteps=10, epochs=10),
]

In [None]:
# Mount Google Drive so the dataset and checkpoint paths under
# '/content/drive/My Drive' become readable and writable.
# NOTE(review): `drive` is not imported explicitly in this notebook; presumably
# it is re-exported by one of the star imports in the first cell - confirm.
drive.mount("/content/drive")

In [None]:
# Train (or restore from a checkpoint) and evaluate every configured network on
# every dataset format, once with binary labels and once with multi-class labels.
for data_format in formats:
  train_lines, test_lines = get_train_and_test_sets(data_format, data_dir)
  print(f"Read {len(train_lines)} lines from train set and {len(test_lines)} from test set")

  for binary_classification in [True, False]:

    # Parse the CSV lines once per labelling mode. Each resulting tuple holds
    # the label at index 1 and the feature array at index 2.
    train_set = [csv_to_tuple(line, False, binary_classification) for line in train_lines]
    test_set  = [csv_to_tuple(line, False, binary_classification) for line in test_lines]

    # Keep the raw labels untouched: every configuration derives its own
    # (possibly one-hot encoded) targets from these arrays.
    y_train_labels = np.array([t[1] for t in train_set])
    y_test_labels  = np.array([t[1] for t in test_set])
    labels = np.unique(y_train_labels)

    class_weights = class_weight.compute_class_weight('balanced',
                                                      classes = labels,
                                                      y = y_train_labels)
    # Keras expects a {class_index: weight} mapping.
    # NOTE(review): this assumes the labels are exactly 0..len(labels)-1 - confirm.
    class_weights = dict(enumerate(class_weights))

    for conf in configurations:
      # Convolution-based networks are only run on the '21x14' format.
      # (The original condition lacked parentheses and compared against the
      # non-existent '21x24', so these networks were always skipped.)
      # NOTE(review): presumably the PCA formats no longer have the per-channel
      # layout these networks need - confirm.
      if data_format != '21x14' and conf.model_type in (TDCNN, CONVOLUTIONAL):
        continue

      # Build the targets for THIS configuration only, always starting from the
      # raw labels, so one configuration's encoding never leaks into the next.
      if conf.to_categorical:
        y_train = keras.utils.to_categorical(y_train_labels, num_classes = len(labels))
        y_test  = keras.utils.to_categorical(y_test_labels,  num_classes = len(labels))
      else:
        y_train, y_test = y_train_labels, y_test_labels

      # get X data
      is_convolutional = conf.model_type == CONVOLUTIONAL
      if is_convolutional:
        # Bind the reshaped data to new names so later configurations still
        # start from the original tuples.
        # NOTE(review): this only helps if reshape_dataset returns a new
        # dataset rather than mutating its argument - confirm.
        conf_train_set = reshape_dataset(train_set)
        conf_test_set = reshape_dataset(test_set)
      else:
        conf_train_set, conf_test_set = train_set, test_set
      X_train, X_test = prepare_x(conf_train_set, conf_test_set, conf.model_type,
                                  convolutional = is_convolutional)

      print(f"Preprocessed data: {len(conf_train_set)} train set {len(conf_test_set)} test set\nexample data: {conf_train_set[0][0]}, label {conf_train_set[0][1]}, x.shape = {conf_train_set[0][2].shape}")

      # configure data generator and model
      checkpoint_path = f"{parent_dir}/TEAA/checkpoints/{data_format}_{conf.dnn.__name__}_{conf.epochs}_epochs_{'bin' if binary_classification else 'multi'}.ckpt"

      if conf.model_type == RECURRENT or conf.model_type == TDCNN:
        # Sequence models are fed through generators configured with
        # conf.timesteps, for both training and validation.
        # NOTE(review): the generators get to_categorical = True even when
        # y_train is already one-hot encoded above - confirm they handle that.
        train_data_generator = MyDataGenerator(X_train, y_train,
                                              shuffle = True,
                                              timesteps = conf.timesteps,
                                              batch_size = 50,
                                              num_classes = len(labels),
                                              to_categorical = True)
        test_data_generator  = MyDataGenerator(X_test,  y_test,
                                              shuffle = False,
                                              timesteps = conf.timesteps,
                                              batch_size = 50,
                                              num_classes = len(labels),
                                              to_categorical = True)

        model = conf.dnn(input_shape = (conf.timesteps,) + X_train.shape[1:], num_labels = len(labels))

      else:
        train_data_generator = MyDataGeneratorToBalanceClasses(X_train, y_train,
                                              shuffle = True,
                                              batch_size = 100,
                                              num_classes = len(labels),
                                              to_categorical = True)
        test_data_generator = None

        model = conf.dnn(input_shape = X_train.shape[1:], num_labels = len(labels))

      # Reuse previously trained weights when an index file exists for this
      # checkpoint path.
      model_existed = os.path.isfile(f"{checkpoint_path}.index")
      if model_existed:
        model.load_weights(checkpoint_path)
        print("Read weights from checkpoint")

      # summary() prints by itself and returns None, so it is not wrapped in print().
      model.summary()

      if not model_existed:
        # train model and save it
        checkpoint_dir = os.path.dirname(checkpoint_path)
        os.makedirs(checkpoint_dir, exist_ok = True)

        cp_callback = keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                        save_weights_only=True,
                                                        verbose=1)

        # The epoch count comes from the configuration, so it always matches
        # the value encoded in the checkpoint file name.
        if conf.model_type == RECURRENT:
          model.fit(x = train_data_generator, epochs = conf.epochs, validation_data = test_data_generator,
                    use_multiprocessing = False, callbacks=[cp_callback])
        elif conf.model_type == TDCNN:
          model.fit(x = train_data_generator, epochs = conf.epochs, validation_data = test_data_generator,
                    use_multiprocessing = True, callbacks=[cp_callback])
        else:
          model.fit(x = train_data_generator, epochs = conf.epochs, validation_data = (X_test, y_test),
                    class_weight = class_weights, use_multiprocessing = True,
                    callbacks=[cp_callback])
      else:
        print("Not training, model existed")

      # get predictions
      y_train_true, y_train_pred, y_test_true, y_test_pred = get_predictions(model, conf, y_train, y_test, X_train, X_test, train_data_generator, test_data_generator)

      print(checkpoint_path)
      show_results(y_train_true, y_train_pred, labels)
      show_results(y_test_true, y_test_pred, labels)


In [None]:
# Flush pending writes (e.g. the checkpoints saved above) and unmount Drive.
drive.flush_and_unmount()