In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
from CogBeaconDataset import CogBeaconDataset

cogbeacon_root_path = '/Users/athenasaghi/VSProjects/CognitiveFatigueDetection/CogFatigueData/CogBeacon/'
dataset = CogBeaconDataset(cogbeacon_root_path)

In [4]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten, Concatenate, BatchNormalization, Dropout, MultiHeadAttention, LayerNormalization, Add
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle, resample
from sklearn.model_selection import train_test_split

# Load Data
df = pd.read_pickle('processed_data_full.pkl')
X_raweeg = np.stack(df['raweeg'].values)
X_features = np.stack(df['features'].values)
Y_labels = df['label'].values

# Remove class 3
mask = Y_labels != 3
X_raweeg = X_raweeg[mask]
X_features = X_features[mask]
Y_labels = Y_labels[mask]

num_classes = len(np.unique(Y_labels))
Y_labels = to_categorical(Y_labels, num_classes=num_classes)

# Split data first
X_train_raweeg, X_test_raweeg, X_train_features, X_test_features, Y_train, Y_test = train_test_split(
    X_raweeg, X_features, Y_labels, test_size=0.1, random_state=42, stratify=np.argmax(Y_labels, axis=1)
)

# Standardize Data (only on training set)
scaler = StandardScaler()
X_train_raweeg = scaler.fit_transform(X_train_raweeg.reshape(-1, X_train_raweeg.shape[-1])).reshape(X_train_raweeg.shape)
X_test_raweeg = scaler.transform(X_test_raweeg.reshape(-1, X_test_raweeg.shape[-1])).reshape(X_test_raweeg.shape)
X_train_features = scaler.fit_transform(X_train_features.reshape(-1, X_train_features.shape[-1])).reshape(X_train_features.shape)
X_test_features = scaler.transform(X_test_features.reshape(-1, X_test_features.shape[-1])).reshape(X_test_features.shape)

# Oversample Minority Classes with Gaussian Noise
X_resampled, Y_resampled, Xf_resampled = [], [], []
labels = np.argmax(Y_train, axis=1)
for cls in np.unique(labels):
    cls_indices = np.where(labels == cls)[0]
    X_cls_raweeg, X_cls_features, Y_cls = resample(
        X_train_raweeg[cls_indices], X_train_features[cls_indices], Y_train[cls_indices],
        n_samples=max([len(np.where(labels == c)[0]) for c in np.unique(labels)]),
        random_state=42
    )
    # Add small noise to avoid overfitting to duplicates
    X_cls_raweeg += np.random.normal(0, 0.01, X_cls_raweeg.shape)
    X_cls_features += np.random.normal(0, 0.01, X_cls_features.shape)
    X_resampled.append(X_cls_raweeg)
    Y_resampled.append(Y_cls)
    Xf_resampled.append(X_cls_features)

X_train_raweeg = np.vstack(X_resampled)
Y_train = np.vstack(Y_resampled)
X_train_features = np.vstack(Xf_resampled)

# Shuffle Dataset
X_train_raweeg, X_train_features, Y_train = shuffle(X_train_raweeg, X_train_features, Y_train, random_state=42)

# Define CLIP-style contrastive loss
def contrastive_loss(z_raweeg, z_features, temperature=0.07):
    z_raweeg = tf.math.l2_normalize(z_raweeg, axis=1)
    z_features = tf.math.l2_normalize(z_features, axis=1)
    logits = tf.matmul(z_raweeg, z_features, transpose_b=True) / temperature
    labels = tf.range(tf.shape(logits)[0])
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

# Pretrain using contrastive learning
encoder_raweeg = tf.keras.models.load_model('encoder_raweeg_clip.h5', compile=False)
encoder_features = tf.keras.models.load_model('encoder_features_clip.h5', compile=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
early_stopping = EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)

for epoch in range(50):
    idx = np.random.permutation(len(X_train_raweeg))
    X_train_raweeg_shuffled, X_train_features_shuffled = X_train_raweeg[idx], X_train_features[idx]
    losses = []
    for i in range(0, len(X_train_raweeg), 64):
        x_r_batch = X_train_raweeg_shuffled[i:i+64]
        x_f_batch = X_train_features_shuffled[i:i+64]
        if x_r_batch.shape[0] != 64:
            continue
        with tf.GradientTape() as tape:
            z_r = encoder_raweeg(x_r_batch, training=True)
            z_f = encoder_features(x_f_batch, training=True)
            loss = contrastive_loss(z_r, z_f)
        grads = tape.gradient(loss, encoder_raweeg.trainable_variables + encoder_features.trainable_variables)
        optimizer.apply_gradients(zip(grads, encoder_raweeg.trainable_variables + encoder_features.trainable_variables))
        losses.append(loss.numpy())
    print(f"Epoch {epoch+1}/50, Contrastive Loss: {np.mean(losses):.4f}")

# Save trained encoders
encoder_raweeg.save('encoder_raweeg_clip_2.h5')
encoder_features.save('encoder_features_clip_2.h5')

# Continue with classification using pretrained encoders
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, val_index in skf.split(X_train_raweeg, np.argmax(Y_train, axis=1)):
    X_train_fold_raweeg, X_val_raweeg = X_train_raweeg[train_index], X_train_raweeg[val_index]
    X_train_fold_features, X_val_features = X_train_features[train_index], X_train_features[val_index]
    Y_train_fold, Y_val = Y_train[train_index], Y_train[val_index]
    
    input_raweeg = Input(shape=(X_train_raweeg.shape[1], X_train_raweeg.shape[2]))
    input_features = Input(shape=(X_train_features.shape[1], X_train_features.shape[2]))
    z_raweeg = encoder_raweeg(input_raweeg)
    z_features = encoder_features(input_features)
    combined = Concatenate()([z_raweeg, z_features])
    output = Dense(num_classes, activation='softmax', kernel_regularizer=l2(1e-5))(combined)
    classifier = Model(inputs=[input_raweeg, input_features], outputs=output)
    classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy', 'AUC'])
    classifier.fit([X_train_fold_raweeg, X_train_fold_features], Y_train_fold, validation_data=([X_val_raweeg, X_val_features], Y_val), epochs=30, batch_size=64, callbacks=[early_stopping])

classifier.save('classifier_clip.h5')


Epoch 1/50, Contrastive Loss: 1.0696
Epoch 2/50, Contrastive Loss: 0.7878
Epoch 3/50, Contrastive Loss: 0.6987
Epoch 4/50, Contrastive Loss: 0.7251
Epoch 5/50, Contrastive Loss: 0.6157
Epoch 6/50, Contrastive Loss: 0.6419
Epoch 7/50, Contrastive Loss: 0.6499
Epoch 8/50, Contrastive Loss: 0.6635
Epoch 9/50, Contrastive Loss: 0.6845
Epoch 10/50, Contrastive Loss: 0.6212
Epoch 11/50, Contrastive Loss: 0.6332


KeyboardInterrupt: 