In [17]:
# General imports
import os
import time
import random
import datetime
import glob
import pickle
import tqdm
import copy
import optuna
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from rich import print as rprint
from rich.pretty import pprint as rpprint
from tqdm import tqdm
from itertools import chain
from functools import partial

# JAX + Keras
os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_USE_LEGACY_KERAS"] = "0"
import jax
import jax.numpy as jnp
import keras
from keras.models import Model
from keras.layers import Dense, Activation, Permute, Dropout
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Conv1D, MaxPooling1D, AveragePooling1D
from keras.layers import SeparableConv2D, DepthwiseConv2D
from keras.layers import BatchNormalization
from keras.layers import SpatialDropout2D
from keras.regularizers import l1_l2
from keras.layers import Input, Flatten
from keras.constraints import max_norm
from keras import backend as K

# Sklearn
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout, TimeDistributed, BatchNormalization
import h5py
from sklearn import preprocessing

# Dataset
from custom_datasets.fatigue_mi import FatigueMI

%load_ext autoreload
%autoreload 3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
SKLRNG = 42
RNG = jax.random.PRNGKey(SKLRNG)

#### Dataset

In [21]:
def data_generator(dataset, subjects = [1], channel_idx = [], filters = ([8, 32],), sfreq = 250):
    
    find_events = lambda raw, event_id: mne.find_events(raw, shortest_event=0, verbose=False) if len(mne.utils._get_stim_channel(None, raw.info, raise_error=False)) > 0 else mne.events_from_annotations(raw, event_id=event_id, verbose=False)[0]
    
    data = dataset.get_data(subjects=(subjects if len(subjects) > 0 else dataset.subject_list))
    
    X = []
    y = []
    metadata = []

    for subject_id in data.keys():
        for session_id in data[subject_id].keys():
            for run_id in data[subject_id][session_id].keys():
                raw = data[subject_id][session_id][run_id]
                
                for fmin, fmax in filters:
                    raw = raw.filter(l_freq = fmin, h_freq = fmax, method = 'iir', picks = 'eeg', verbose = False)
                
                events = find_events(raw, dataset.event_id)

                tmin = dataset.interval[0]
                tmax = dataset.interval[1]

                channels = np.asarray(raw.info['ch_names'])[channel_idx] if len(channel_idx) > 0 else np.asarray(raw.info['ch_names'])

                # rpprint(channels)
                
                stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
                picks = mne.pick_channels(raw.info["ch_names"], include=channels, exclude=stim_channels, ordered=True)

                x = mne.Epochs(
                    raw,
                    events,
                    event_id=dataset.event_id,
                    tmin=tmin,
                    tmax=tmax,
                    proj=False,
                    baseline=None,
                    preload=True,
                    verbose=False,
                    picks=picks,
                    event_repeated="drop",
                    on_missing="ignore",
                )
                x_events = x.events
                inv_events = {k: v for v, k in dataset.event_id.items()}
                labels = [inv_events[e] for e in x_events[:, -1]]

                # rpprint({
                #     "X": np.asarray(x.get_data(copy=False)).shape,
                #     "y": np.asarray(labels).shape,
                #     "channels selected": np.asarray(raw.info['ch_names'])[channel_idx]
                # })

                # x.plot(scalings="auto")
                
                x_resampled = x.resample(sfreq) if sfreq != raw.info['sfreq'] else x # Resampler_Epoch
                x_resampled_data = x_resampled.get_data(copy=False) # Convert_Epoch_Array
                x_resampled_data_standard_scaler = np.asarray([
                    StandardScaler().fit_transform(x_resampled_data[i])
                    for i in np.arange(x_resampled_data.shape[0])
                ]) # Standard_Scaler_Epoch

                n = x_resampled_data_standard_scaler.shape[0] if sfreq != 250 else x.get_data(copy=False).shape[0]
                # n = x.get_data(copy=False).shape[0]
                met = pd.DataFrame(index=range(n))
                met["subject"] = subject_id
                met["session"] = session_id
                met["run"] = run_id
                x.metadata = met.copy()
                
                # X.append(x_resampled_data_standard_scaler)
                X.append(x)
                y.append(labels)
                metadata.append(met)

    return np.concatenate(X, axis=0), np.concatenate(y), pd.concat(metadata, ignore_index=True)

fat_dataset = FatigueMI()
X, y, _ = data_generator(fat_dataset, subjects=[1], channel_idx=[], sfreq=128) # if channel_idx=[], all channels will be used

Adding metadata with 3 columns


#### Models

In [22]:
def lstm_cnn_net(nb_classes, channels, samples, dropout_rate_1 = 0.5, dropout_rate_2 = 0.5):
    model = Sequential()
    model.add(Input(shape=(samples, channels)))

    # add 1-layer cnn
    model.add(Conv1D(40, kernel_size=20, strides=4))
    model.add(Activation('relu'))
    model.add(Dropout(dropout_rate_1))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=4, strides=4))


    # add 1-layer lstm
    model.add(LSTM(50, return_sequences=True, stateful=False))
    # model.add(LSTM(128, return_sequences=True, stateful=False))
    # model.add(LSTM(32, return_sequences=True, stateful=False))
    model.add(Dropout(dropout_rate_2))
    model.add(BatchNormalization())
    model.add(Flatten())
    model.add(Dense(nb_classes, activation='softmax'))
    
    return model


#### Heuristic search for hyperparameter optimization

In [28]:
def objective_fn(trial, subjects=[], model_str = "lstm_cnn_net", **kwargs):
    _SFREQ_ = 128
    _TRAIN_SIZE_ = 0.8
    _TEST_SIZE_ = 1 - _TRAIN_SIZE_
    _BATCH_SIZE = 256

    models = {
        "lstm_cnn_net": lstm_cnn_net
    }

    subjects = subjects if len(subjects) > 0 else fat_dataset.subject_list
    model_fn = models[model_str]
    
    channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
    channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    while len(channels_idx) == 0:
        channels_dict = { k: trial.suggest_categorical(f"channels_{k}", [True, False]) for k in fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'][:-1] }
        channels_idx = [i for i, v in enumerate(channels_dict.values()) if v]
    
    X, y, _ = data_generator(fat_dataset, subjects=subjects, channel_idx=channels_idx, sfreq=_SFREQ_)
    
    _NUM_SAMPLES_ = X.shape[-1]
    _NUM_CHANNELS_ = X.shape[-2]
    
    y_encoded = LabelEncoder().fit_transform(y)
    _NUM_CLASSES_ = len(np.unique(y_encoded))
    X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_), y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=None, shuffle=False)
    # X_train, X_test, y_train, y_test = train_test_split(X.reshape(-1, _NUM_SAMPLES_, _NUM_CHANNELS_), y_encoded, test_size=_TEST_SIZE_, random_state=SKLRNG, stratify=y_encoded)
    
    model_params = {
        "nb_classes": _NUM_CLASSES_,
        "channels": _NUM_CHANNELS_,
        "samples": _NUM_SAMPLES_,
    }
    
    model = model_fn(**model_params)
    model.compile(
        loss="sparse_categorical_crossentropy", 
        optimizer="adam", 
        metrics=["accuracy"]
    )
    
    history = model.fit(
        X_train,
        y_train,
        batch_size=_BATCH_SIZE,
        epochs=100,
        validation_split=0.2,
        shuffle=True,
        callbacks=[
            # keras.callbacks.EarlyStopping(monitor="val_loss", patience=75, restore_best_weights=True),
            # keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=75, factor=0.5)
            keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.1)
        ],
    )

    trial.set_user_attr("trial_data", {
        "channels_selected": np.asarray(fat_dataset.get_data(subjects=[1])[1]['0']['0'].info['ch_names'])[channels_idx],
        # "history": history,
        "model": model.to_json(),
        "val_accuracy": history.history["val_accuracy"],
        "test_accuracy": model.evaluate(X_test, y_test, batch_size=_BATCH_SIZE)[1]
    })

    # return (1 - np.mean(history.history["val_accuracy"][-2:])) + 0.005*len(channels_idx)
    return (1 - history.history["val_accuracy"][-1])**2 + 5e-5*len(channels_idx)

def heuristic_optimizer(obj_fn, max_iter = 15, show_progress_bar = True, subjects=[], model_str="lstm_cnn_net", **kwargs):
    
    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    objective = partial(obj_fn, subjects = subjects, model_str = model_str)
    
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.NSGAIISampler())
    study.optimize(objective, n_trials=max_iter, show_progress_bar=show_progress_bar, callbacks=[], **kwargs)
    
    return study

#### Search for best subjects

In [29]:
subject_results = {}

for subject in fat_dataset.subject_list:
   study = heuristic_optimizer(objective_fn, subjects=[subject], model_str="lstm_cnn_net")
   subject_results[subject] = study.best_trial
   rpprint({
      "subject": subject,
      "test_accuracy": study.best_trial.user_attrs["trial_data"]["test_accuracy"],
      "val_accuracy": study.best_trial.user_attrs["trial_data"]["val_accuracy"][-1],
      "channels_selected": study.best_trial.user_attrs["trial_data"]["channels_selected"]
   })

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5588 - loss: 1.0788 - val_accuracy: 0.6111 - val_loss: 0.6035 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5441 - loss: 0.9715 - val_accuracy: 0.5556 - val_loss: 0.5984 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5147 - loss: 1.0371 - val_accuracy: 0.6111 - val_loss: 0.6048 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5588 - loss: 1.1292 - val_accuracy: 0.6111 - val_loss: 0.6102 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5147 - loss: 0.9492 - val_accuracy: 0.6111 - val_loss: 0.6114 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 1.0946 - val_accuracy: 0.5000 - val_loss: 0.6830 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5294 - loss: 0.9939 - val_accuracy: 0.5000 - val_loss: 0.6881 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5294 - loss: 1.0225 - val_accuracy: 0.5556 - val_loss: 0.6854 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5735 - loss: 0.9691 - val_accuracy: 0.5000 - val_loss: 0.6873 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4853 - loss: 1.0327 - val_accuracy: 0.5000 - val_loss: 0.6865 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5735 - loss: 0.8795 - val_accuracy: 0.4444 - val_loss: 0.7523 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5000 - loss: 0.9143 - val_accuracy: 0.3889 - val_loss: 0.7474 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5882 - loss: 0.9389 - val_accuracy: 0.3889 - val_loss: 0.7520 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.4412 - loss: 0.9370 - val_accuracy: 0.4444 - val_loss: 0.7401 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.6176 - loss: 0.7643 - val_accuracy: 0.3889 - val_loss: 0.7366 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4412 - loss: 1.2423 - val_accuracy: 0.5000 - val_loss: 0.7175 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4853 - loss: 1.2002 - val_accuracy: 0.5556 - val_loss: 0.7021 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.4706 - loss: 1.1696 - val_accuracy: 0.5556 - val_loss: 0.7049 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5000 - loss: 0.9665 - val_accuracy: 0.5556 - val_loss: 0.7111 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5294 - loss: 0.8399 - val_accuracy: 0.5556 - val_loss: 0.7172 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4706 - loss: 0.9695 - val_accuracy: 0.2778 - val_loss: 0.7813 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5441 - loss: 0.9002 - val_accuracy: 0.4444 - val_loss: 0.7597 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5294 - loss: 1.0615 - val_accuracy: 0.4444 - val_loss: 0.7392 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.4853 - loss: 0.9366 - val_accuracy: 0.5556 - val_loss: 0.7220 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5735 - loss: 0.9685 - val_accuracy: 0.5556 - val_loss: 0.7090 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 0.9581 - val_accuracy: 0.5556 - val_loss: 0.6997 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.6176 - loss: 0.8691 - val_accuracy: 0.5556 - val_loss: 0.6971 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5588 - loss: 1.0241 - val_accuracy: 0.5556 - val_loss: 0.6927 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5882 - loss: 0.8979 - val_accuracy: 0.5556 - val_loss: 0.6919 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.6029 - loss: 0.8375 - val_accuracy: 0.5556 - val_loss: 0.6895 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5000 - loss: 1.3139 - val_accuracy: 0.5556 - val_loss: 0.7172 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5441 - loss: 0.8888 - val_accuracy: 0.5556 - val_loss: 0.7092 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5294 - loss: 0.8819 - val_accuracy: 0.6111 - val_loss: 0.6997 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5735 - loss: 0.9054 - val_accuracy: 0.6111 - val_loss: 0.6891 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6029 - loss: 0.7732 - val_accuracy: 0.6667 - val_loss: 0.6757 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5294 - loss: 0.9438 - val_accuracy: 0.7222 - val_loss: 0.6201 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.3971 - loss: 1.2512 - val_accuracy: 0.7222 - val_loss: 0.6342 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5735 - loss: 0.7686 - val_accuracy: 0.6667 - val_loss: 0.6420 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5000 - loss: 0.8168 - val_accuracy: 0.6667 - val_loss: 0.6478 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.5588 - loss: 0.8193 - val_accuracy: 0.6111 - val_loss: 0.6527 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 1.0764 - val_accuracy: 0.5000 - val_loss: 0.7137 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4412 - loss: 1.1064 - val_accuracy: 0.5000 - val_loss: 0.7205 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5147 - loss: 1.0103 - val_accuracy: 0.5556 - val_loss: 0.7222 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.6618 - loss: 0.7525 - val_accuracy: 0.5556 - val_loss: 0.7187 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5735 - loss: 0.7454 - val_accuracy: 0.5556 - val_loss: 0.7202 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5000 - loss: 1.0679 - val_accuracy: 0.4444 - val_loss: 0.8649 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 43ms/step - accuracy: 0.5294 - loss: 0.9000 - val_accuracy: 0.3889 - val_loss: 0.8452 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.4706 - loss: 1.0769 - val_accuracy: 0.4444 - val_loss: 0.8232 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6324 - loss: 0.7392 - val_accuracy: 0.3889 - val_loss: 0.8097 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5735 - loss: 0.7937 - val_accuracy: 0.3889 - val_loss: 0.8036 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 1.0442 - val_accuracy: 0.4444 - val_loss: 0.7251 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.5441 - loss: 1.0374 - val_accuracy: 0.5556 - val_loss: 0.7041 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.4706 - loss: 1.1364 - val_accuracy: 0.5556 - val_loss: 0.7043 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.4706 - loss: 1.0874 - val_accuracy: 0.5000 - val_loss: 0.7120 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.6324 - loss: 0.8533 - val_accuracy: 0.3889 - val_loss: 0.7135 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.3971 - loss: 1.0791 - val_accuracy: 0.5000 - val_loss: 0.7338 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.4559 - loss: 1.1578 - val_accuracy: 0.5000 - val_loss: 0.7409 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.4853 - loss: 1.0580 - val_accuracy: 0.5000 - val_loss: 0.7384 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5441 - loss: 1.0398 - val_accuracy: 0.5000 - val_loss: 0.7419 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step - accuracy: 0.5735 - loss: 0.8100 - val_accuracy: 0.5000 - val_loss: 0.7421 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6029 - loss: 1.0739 - val_accuracy: 0.5000 - val_loss: 0.7322 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4853 - loss: 1.0662 - val_accuracy: 0.4444 - val_loss: 0.7466 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.4706 - loss: 1.1390 - val_accuracy: 0.4444 - val_loss: 0.7512 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.4706 - loss: 1.0349 - val_accuracy: 0.4444 - val_loss: 0.7538 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step - accuracy: 0.4706 - loss: 1.0858 - val_accuracy: 0.4444 - val_loss: 0.7547 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 1.1210 - val_accuracy: 0.4444 - val_loss: 0.8035 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.5000 - loss: 0.9270 - val_accuracy: 0.4444 - val_loss: 0.7972 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5000 - loss: 1.1133 - val_accuracy: 0.4444 - val_loss: 0.7892 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.5882 - loss: 0.8753 - val_accuracy: 0.5000 - val_loss: 0.7838 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5735 - loss: 0.8545 - val_accuracy: 0.5000 - val_loss: 0.7776 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 1.1121 - val_accuracy: 0.3889 - val_loss: 0.7303 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4559 - loss: 1.0179 - val_accuracy: 0.5000 - val_loss: 0.7410 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5588 - loss: 0.8531 - val_accuracy: 0.5000 - val_loss: 0.7477 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5588 - loss: 0.9158 - val_accuracy: 0.4444 - val_loss: 0.7539 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.3824 - loss: 1.2349 - val_accuracy: 0.3889 - val_loss: 0.7537 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5735 - loss: 1.1200 - val_accuracy: 0.4444 - val_loss: 0.7579 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5147 - loss: 1.1674 - val_accuracy: 0.5556 - val_loss: 0.7284 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.6324 - loss: 1.0970 - val_accuracy: 0.4444 - val_loss: 0.7357 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5882 - loss: 0.8071 - val_accuracy: 0.5556 - val_loss: 0.7223 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5735 - loss: 0.8842 - val_accuracy: 0.5000 - val_loss: 0.7207 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 1.2202 - val_accuracy: 0.3889 - val_loss: 0.7256 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4559 - loss: 0.9964 - val_accuracy: 0.5000 - val_loss: 0.7169 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5588 - loss: 0.8232 - val_accuracy: 0.5556 - val_loss: 0.7077 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.5147 - loss: 1.0902 - val_accuracy: 0.5000 - val_loss: 0.7073 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5000 - loss: 1.0182 - val_accuracy: 0.5000 - val_loss: 0.7080 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 1.2094 - val_accuracy: 0.4444 - val_loss: 0.7680 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.4853 - loss: 1.2047 - val_accuracy: 0.4444 - val_loss: 0.7628 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6176 - loss: 0.8404 - val_accuracy: 0.3889 - val_loss: 0.7444 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5000 - loss: 1.0013 - val_accuracy: 0.3889 - val_loss: 0.7346 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5000 - loss: 0.9278 - val_accuracy: 0.5000 - val_loss: 0.7270 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5147 - loss: 0.9604 - val_accuracy: 0.5556 - val_loss: 0.8120 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.4559 - loss: 0.9776 - val_accuracy: 0.5556 - val_loss: 0.7846 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5735 - loss: 0.8909 - val_accuracy: 0.5556 - val_loss: 0.7574 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.4412 - loss: 0.9631 - val_accuracy: 0.5556 - val_loss: 0.7385 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.4853 - loss: 0.9262 - val_accuracy: 0.5556 - val_loss: 0.7280 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 1.0104 - val_accuracy: 0.3889 - val_loss: 0.8218 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.5147 - loss: 1.0686 - val_accuracy: 0.4444 - val_loss: 0.7882 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5735 - loss: 1.1313 - val_accuracy: 0.5000 - val_loss: 0.7567 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.5441 - loss: 0.8224 - val_accuracy: 0.5000 - val_loss: 0.7432 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5882 - loss: 0.9285 - val_accuracy: 0.3889 - val_loss: 0.7352 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 1.0751 - val_accuracy: 0.6111 - val_loss: 0.6582 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4853 - loss: 0.9204 - val_accuracy: 0.6111 - val_loss: 0.6528 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.4559 - loss: 0.9871 - val_accuracy: 0.6111 - val_loss: 0.6533 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5294 - loss: 0.8317 - val_accuracy: 0.6111 - val_loss: 0.6559 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5294 - loss: 0.8202 - val_accuracy: 0.5556 - val_loss: 0.6584 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 0.9617 - val_accuracy: 0.5556 - val_loss: 0.7215 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.4706 - loss: 0.9623 - val_accuracy: 0.5000 - val_loss: 0.6979 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5735 - loss: 0.7839 - val_accuracy: 0.3333 - val_loss: 0.6940 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6912 - loss: 0.6963 - val_accuracy: 0.3889 - val_loss: 0.6980 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5441 - loss: 0.8656 - val_accuracy: 0.5000 - val_loss: 0.7009 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 0.8328 - val_accuracy: 0.4444 - val_loss: 0.8677 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6912 - loss: 0.8799 - val_accuracy: 0.5000 - val_loss: 0.8266 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5441 - loss: 0.8886 - val_accuracy: 0.5000 - val_loss: 0.7991 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5000 - loss: 0.9118 - val_accuracy: 0.5000 - val_loss: 0.7842 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5588 - loss: 0.8512 - val_accuracy: 0.5000 - val_loss: 0.7716 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.3824 - loss: 1.2563 - val_accuracy: 0.4444 - val_loss: 0.9458 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6618 - loss: 0.8619 - val_accuracy: 0.4444 - val_loss: 0.9300 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5882 - loss: 0.8665 - val_accuracy: 0.4444 - val_loss: 0.9221 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - accuracy: 0.5735 - loss: 0.9234 - val_accuracy: 0.5000 - val_loss: 0.9171 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6029 - loss: 0.8915 - val_accuracy: 0.5000 - val_loss: 0.9082 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - accuracy: 0.6029 - loss: 0.8293 - val_accuracy: 0.5000 - val_loss: 0.7310 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.6765 - loss: 0.7976 - val_accuracy: 0.5000 - val_loss: 0.7230 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5588 - loss: 0.9604 - val_accuracy: 0.5000 - val_loss: 0.7265 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5735 - loss: 0.9435 - val_accuracy: 0.5000 - val_loss: 0.7237 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.6176 - loss: 0.8197 - val_accuracy: 0.5000 - val_loss: 0.7179 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.6324 - loss: 0.7708 - val_accuracy: 0.4444 - val_loss: 1.0596 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.5294 - loss: 0.9043 - val_accuracy: 0.4444 - val_loss: 1.0334 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - accuracy: 0.6176 - loss: 0.6243 - val_accuracy: 0.4444 - val_loss: 1.0070 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.5147 - loss: 0.9830 - val_accuracy: 0.4444 - val_loss: 0.9882 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6029 - loss: 0.8897 - val_accuracy: 0.4444 - val_loss: 0.9747 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 0.8472 - val_accuracy: 0.5000 - val_loss: 0.7886 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.5147 - loss: 0.8877 - val_accuracy: 0.5000 - val_loss: 0.7982 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5441 - loss: 0.9909 - val_accuracy: 0.5000 - val_loss: 0.8036 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step - accuracy: 0.6471 - loss: 0.6643 - val_accuracy: 0.4444 - val_loss: 0.7937 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.5588 - loss: 0.7806 - val_accuracy: 0.4444 - val_loss: 0.7934 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4559 - loss: 1.1587 - val_accuracy: 0.6111 - val_loss: 0.6594 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5735 - loss: 0.7949 - val_accuracy: 0.6667 - val_loss: 0.6617 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.4118 - loss: 1.1420 - val_accuracy: 0.6667 - val_loss: 0.6457 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5000 - loss: 0.8893 - val_accuracy: 0.6111 - val_loss: 0.6399 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5882 - loss: 0.8484 - val_accuracy: 0.6111 - val_loss: 0.6390 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 1.0533 - val_accuracy: 0.5000 - val_loss: 0.8102 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.5882 - loss: 0.8991 - val_accuracy: 0.5000 - val_loss: 0.8040 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5441 - loss: 0.9137 - val_accuracy: 0.5000 - val_loss: 0.7945 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.6471 - loss: 0.7361 - val_accuracy: 0.5000 - val_loss: 0.7882 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5441 - loss: 0.9833 - val_accuracy: 0.4444 - val_loss: 0.7817 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5441 - loss: 0.8432 - val_accuracy: 0.5000 - val_loss: 0.7256 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6912 - loss: 0.7854 - val_accuracy: 0.4444 - val_loss: 0.7706 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - accuracy: 0.6029 - loss: 0.7408 - val_accuracy: 0.4444 - val_loss: 0.8287 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.5441 - loss: 0.8321 - val_accuracy: 0.4444 - val_loss: 0.8919 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.5735 - loss: 0.7952 - val_accuracy: 0.4444 - val_loss: 0.8924 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.5294 - loss: 1.0812 - val_accuracy: 0.5556 - val_loss: 0.7770 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.4412 - loss: 0.9909 - val_accuracy: 0.5556 - val_loss: 0.7758 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 0.4706 - loss: 0.9483 - val_accuracy: 0.5556 - val_loss: 0.7766 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5735 - loss: 0.8525 - val_accuracy: 0.5556 - val_loss: 0.7637 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - accuracy: 0.5882 - loss: 1.0030 - val_accuracy: 0.5556 - val_loss: 0.7581 - learning_rate: 0.0010
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[

  0%|          | 0/15 [00:00<?, ?it/s]

Adding metadata with 3 columns
Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step - accuracy: 0.4853 - loss: 1.0150 - val_accuracy: 0.5556 - val_loss: 0.7186 - learning_rate: 0.0010
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step - accuracy: 0.4853 - loss: 0.9907 - val_accuracy: 0.5556 - val_loss: 0.7229 - learning_rate: 0.0010
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.6324 - loss: 0.7661 - val_accuracy: 0.5556 - val_loss: 0.7289 - learning_rate: 0.0010
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.5588 - loss: 1.0294 - val_accuracy: 0.5556 - val_loss: 0.7302 - learning_rate: 0.0010
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.5882 - loss: 0.9227 - val_accuracy: 0.5556 - val_loss: 0.7318 - learning_rate: 1.0000e-04
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━

In [30]:
# Save subject_results
with open(f'subject_results-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}.pickle', 'wb') as f:
    pickle.dump(subject_results, f, protocol=pickle.HIGHEST_PROTOCOL)

In [45]:
subject_results_val_acc = { subject: subject_results[subject].user_attrs["trial_data"]["val_accuracy"][-1] for subject in fat_dataset.subject_list }
subject_results_test_acc = { subject: subject_results[subject].user_attrs["trial_data"]["test_accuracy"] for subject in fat_dataset.subject_list }

subject_results_val_acc_df = pd.DataFrame(subject_results_val_acc, index=[0], columns=fat_dataset.subject_list).T.rename(columns={0: 'val_acc'}).sort_values(by='val_acc', ascending=False)
subject_results_test_acc_df = pd.DataFrame(subject_results_test_acc, index=[0], columns=fat_dataset.subject_list).T.rename(columns={0: 'test_acc'}).sort_values(by='test_acc', ascending=False)

subject_results_acc_df = subject_results_val_acc_df.join(subject_results_test_acc_df)

subject_results_acc_df

Unnamed: 0,val_acc,test_acc
17,0.888889,0.5
29,0.833333,0.590909
12,0.777778,0.409091
5,0.777778,0.545455
6,0.777778,0.454545
28,0.777778,0.590909
24,0.777778,0.454545
31,0.722222,0.545455
25,0.722222,0.5
23,0.722222,0.5
