## Imports

In [None]:
import os
import h5py
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt

from keras.models import Sequential # type: ignore
from keras.layers import Conv1D, LSTM, Dropout, BatchNormalization, Dense, GlobalAveragePooling1D, TimeDistributed # type: ignore
from keras.optimizers import Nadam # type: ignore
from keras.callbacks import EarlyStopping # type: ignore
from keras.utils import to_categorical # type: ignore
from keras.regularizers import l2 # type: ignore
from sklearn.utils import shuffle # type: ignore
from sklearn.model_selection import train_test_split # type: ignore

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

## Constants

In [None]:
# Keep every DOWNSAMPLING_STEP-th sample along the time axis when loading.
DOWNSAMPLING_STEP = 2
# Number of output classes; should match the number of labels in ENCODE_MAP.
NUM_CLASSES = 4
# Window length (in samples after downsampling) each recording is cut into.
# NOTE: despite the name, this is not the mini-batch size passed to fit().
BATCH_SIZE = 400

# Filename substring -> integer class label. Insertion order matters:
# extract_label tries the keys in this order and uses the first match.
ENCODE_MAP = dict(
    rest=0,
    motor=1,
    memory=2,
    math=3,
)

## Loading and preprocessing

In [None]:
def get_dataset_name(filename_with_dir):
    """Return the HDF5 dataset name stored inside a recording file.

    The dataset name is the file's base name with its last ``_``-separated
    component removed (presumably a chunk index plus the ``.h5`` extension),
    e.g. ``Intra/train/rest_105923_1.h5`` -> ``rest_105923``.

    Args:
        filename_with_dir: Path or path string of the ``.h5`` file; leading
            directories are ignored.

    Returns:
        The dataset name, or ``""`` if the base name contains no ``_``.
    """
    # Path() accepts both str and Path, so callers may pass either form.
    name_parts = Path(filename_with_dir).name.split('_')
    return '_'.join(name_parts[:-1])

def extract_label(filename, logs=False, encode_mapping=ENCODE_MAP):
    """Map a recording filename to its integer class label.

    The keys of ``encode_mapping`` are tried in iteration order, and the value
    of the first key that occurs as a substring of ``filename`` is returned.
    If no key matches, ``encode_mapping['math']`` is returned (so a mapping
    without a ``'math'`` key raises ``KeyError`` in that case).

    Args:
        filename: File name (or any string) to inspect.
        logs: If truthy, print the mapping when a key matches.
        encode_mapping: Substring -> label dictionary; defaults to ENCODE_MAP.

    Returns:
        The integer label for the first matching key, or the 'math' label.
    """
    matched_key = next((key for key in encode_mapping if key in filename), None)
    if matched_key is None:
        # NOTE(review): unmatched filenames silently become 'math'; consider
        # raising instead so mislabeled files are caught.
        return encode_mapping['math']

    label = encode_mapping[matched_key]
    if logs:
        print(f"Mapping label {filename} to {label}")
    return label

def load_all_data(folder_path, logs=False, batch_size=100, step=4):
    """Load every ``.h5`` recording in a folder and cut it into fixed windows.

    Each file, visited in sorted name order, goes through these steps:

    1. Read the dataset named by ``get_dataset_name`` and transpose it so
       that time is axis 0.
    2. Z-score every channel over time, per recording.
    3. Keep every ``step``-th time sample.
    4. Split into non-overlapping windows of ``batch_size`` samples, dropping
       any trailing remainder.

    Every window gets the file's label from ``extract_label``.

    Args:
        folder_path: Directory containing the ``.h5`` files (str or Path).
        logs: If truthy, print per-file loading details.
        batch_size: Window length, in samples after downsampling.
        step: Downsampling stride along the time axis.

    Returns:
        ``(X, y)``: ``X`` of shape ``(n_windows, batch_size, n_channels)`` and
        ``y`` of shape ``(n_windows,)``. Both are empty 1-D arrays when no
        window could be produced.

    Raises:
        KeyError: If a file does not contain the expected dataset.
    """
    folder_path = Path(folder_path)
    X, y = [], []

    # os.listdir order is arbitrary; sorting makes the sample order (and any
    # seeded shuffle/split applied to it later) reproducible across machines.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.h5'):
            continue
        if logs:
            print(f"Loading {filename}", end=" ")

        file_path = folder_path / filename
        with h5py.File(file_path, 'r') as f:
            dataset_name = get_dataset_name(file_path)
            # f[name] raises a KeyError naming the missing dataset, unlike
            # f.get(name), which returns None and fails later on indexing.
            data = f[dataset_name][()]  # shape (248, time)
        data = data.T  # (time, 248)
        if logs:
            print(f"Original: {data.shape}", end=" ")

        # Normalize each channel over time; the epsilon guards flat channels.
        mean = np.mean(data, axis=0, keepdims=True)
        std = np.std(data, axis=0, keepdims=True) + 1e-8
        data = (data - mean) / std

        # Downsample the time dimension.
        data = data[::step, :]  # (downsampled_time, 248)
        if logs:
            print(f"Downsampled: {data.shape}", end=" ")

        label = extract_label(filename, logs=0)
        if logs:
            print(f"Label: {label}")

        num_batches = len(data) // batch_size
        if num_batches == 0:
            # Too short for even one window (np.split(..., 0) would raise
            # ZeroDivisionError), so report and skip the file explicitly.
            print(f"Skipping {filename}: {len(data)} samples < batch_size={batch_size}")
            continue

        # Consecutive, non-overlapping windows; the remainder is dropped.
        windows = data[:num_batches * batch_size].reshape(num_batches, batch_size, -1)
        X.extend(windows)
        y.extend([label] * num_batches)

    X, y = np.array(X), np.array(y)
    print(f"Final shape of X: {X.shape}, y: {y.shape}")
    return X, y

In [None]:
X_train, y_train = load_all_data(
    "Intra/train",
    logs=0,
    batch_size=BATCH_SIZE,
    step=DOWNSAMPLING_STEP,
)
print("Shape after preprocessing:", X_train.shape)

# load_all_data emits windows file by file; shuffle windows and labels with
# the same seeded permutation so they are no longer grouped by recording.
X_train, y_train = shuffle(X_train, y_train, random_state=42)
print("Shape after shuffling:", X_train.shape)

## Model 1: LSTM baseline

In [None]:
# Model input geometry from the windowed data: X_train is
# (n_windows, time_steps, num_channels).
time_steps = X_train.shape[1]
num_channels = X_train.shape[2]

# categorical_crossentropy needs one-hot targets; the ndim check leaves an
# already-encoded y_train untouched, so this cell can be re-run safely.
y_train = (
    to_categorical(y_train, num_classes=NUM_CLASSES)
    if y_train.ndim == 1
    else y_train
)

baseline_layers = [
    # Per-time-step linear projection of the channels to 128 features.
    TimeDistributed(Dense(128, kernel_regularizer=l2(1e-4)), input_shape=(time_steps, num_channels)),

    # First LSTM returns the full sequence so the second LSTM can consume it.
    LSTM(128, return_sequences=True, kernel_regularizer=l2(1e-4)),
    BatchNormalization(),
    Dropout(0.5),

    # Second LSTM returns only its final hidden state.
    LSTM(64, kernel_regularizer=l2(1e-3)),
    BatchNormalization(),
    Dropout(0.5),

    Dense(64, activation='relu', kernel_regularizer=l2(1e-3)),
    Dropout(0.5),

    Dense(NUM_CLASSES, activation='softmax'),
]
lstm_model = Sequential(baseline_layers)

lstm_model.compile(
    optimizer=Nadam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Stop when val_loss has not improved for 3 epochs; roll back to the best weights.
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)


In [None]:
# Keras takes validation_split from the LAST samples of the arrays passed in
# (before its own per-epoch shuffling), so it relies on X_train having been
# shuffled when it was loaded.
fit_config = dict(
    epochs=5,
    batch_size=20,
    validation_split=0.5,
    callbacks=[early_stop],
    verbose=1,
)
history = lstm_model.fit(X_train, y_train, **fit_config)

## Model 2: LSTM with a stratified validation split

In [None]:
# One-hot encode the labels unless an earlier cell already did so.
if y_train.ndim == 1:
    y_train = to_categorical(y_train, num_classes=NUM_CLASSES)

# Hold out 20% of the windows for validation, stratified on the class index
# so each class appears in the same proportion in both splits.
# NOTE(review): the split is per window, so windows cut from the same
# recording can land in both splits and validation scores may be optimistic.
X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train.argmax(axis=1))

# Input geometry for the model definitions that follow: (time_steps, num_channels).
time_steps = X_train.shape[1]
num_channels = X_train.shape[2]

In [None]:
# Class counts per split (index = class label). minlength reserves a slot for
# every class, so a class missing from a split shows up as 0 instead of
# silently shortening the array. numpy is already imported at the top.
print("Train labels distribution:", np.bincount(y_train_split.argmax(axis=1), minlength=NUM_CLASSES))
print("Validation labels distribution:", np.bincount(y_val_split.argmax(axis=1), minlength=NUM_CLASSES))


In [None]:
# Only one of the two architectures below is built; flip the flag to try the
# Conv1D front-end variant instead of the plain stacked LSTM.
USE_CONV_FRONTEND = False

if USE_CONV_FRONTEND:
    # Variant: a kernel_size=1 linear Conv1D projects the channels of each time
    # step to 64 features, followed by a regularised two-layer LSTM.
    lstm_model = Sequential([
        Conv1D(64, kernel_size=1, activation='linear', input_shape=(time_steps, num_channels)),

        LSTM(128, return_sequences=True, kernel_regularizer=l2(1e-4)),
        BatchNormalization(),
        Dropout(0.4),

        LSTM(64, kernel_regularizer=l2(1e-3)),
        BatchNormalization(),
        Dropout(0.6),

        Dense(32, activation='relu', kernel_regularizer=l2(1e-3)),
        Dropout(0.4),

        Dense(NUM_CLASSES, activation='softmax')
    ])
else:
    # Default: plain two-layer LSTM with dropout.
    lstm_model = Sequential([
        LSTM(128, return_sequences=True, input_shape=(time_steps, num_channels)),
        Dropout(0.3),

        LSTM(64),
        Dropout(0.3),

        Dense(NUM_CLASSES, activation='softmax')
    ])

lstm_model.compile(
    optimizer=Nadam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
# Stop after 5 epochs without val_loss improvement and keep the best weights.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Train on the stratified training split and validate on the explicit hold-out.
history = lstm_model.fit(
    x=X_train_split,
    y=y_train_split,
    validation_data=(X_val_split, y_val_split),
    epochs=15,
    batch_size=32,
    callbacks=[early_stop],
    verbose=1,
)

## Testing and evaluation

In [None]:
# Load the intra-subject test windows with the SAME window length and
# downsampling stride as the training data, so each sample matches the
# (time_steps, num_channels) input the model was built for. load_all_data
# already z-scores each recording per channel, so no further normalisation
# is applied here.
# NOTE(review): training loads "Intra/train" while this loads
# "Data/Intra/test"; confirm both paths share the same data root.
X_test, y_test = load_all_data(
    "Data/Intra/test", logs=0, batch_size=BATCH_SIZE, step=DOWNSAMPLING_STEP
)
print(f"X_test shape after preprocessing: {X_test.shape}")

# One-hot encode labels to match the categorical_crossentropy training targets.
y_test = to_categorical(y_test, num_classes=NUM_CLASSES)

In [None]:
# Score the trained model on the intra-subject test windows; with the compiled
# metrics, evaluate() returns [loss, accuracy].
scores = lstm_model.evaluate(X_test, y_test, verbose=1)
test_loss, test_acc = scores
print("Test accuracy:", test_acc)

In [None]:
# Training curves from the most recent fit(): accuracy (left, fixed 0-1 range)
# and loss (right), each comparing the training and validation series.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(history.history['accuracy'], label='Train Acc')
ax_acc.plot(history.history['val_accuracy'], label='Val Acc')
ax_acc.set_title('Model Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_ylim((0, 1))
ax_acc.legend()

ax_loss.plot(history.history['loss'], label='Train Loss')
ax_loss.plot(history.history['val_loss'], label='Val Loss')
ax_loss.set_title('Model Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

fig.tight_layout()
plt.show()


In [None]:
# Class-probability predictions for every test window.
y_pred_probs = lstm_model.predict(X_test)
# Most probable class per window, and the true class from the one-hot targets.
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.argmax(y_test, axis=1)

# Class names in label order, derived from ENCODE_MAP so they cannot drift
# from the encoding used when the data was loaded.
class_names = sorted(ENCODE_MAP, key=ENCODE_MAP.get)

# labels= fixes the matrix at NUM_CLASSES x NUM_CLASSES even when a class never
# appears in y_true or y_pred; otherwise the matrix shrinks and plotting it with
# all class names fails on the length mismatch.
cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.show()


# Cross-subject testing

In [None]:
# Load the cross-subject test windows with the SAME window length and
# downsampling stride as the training data, so each sample matches the
# (time_steps, num_channels) input the model was built for. load_all_data
# already z-scores each recording per channel, so no further normalisation
# is applied here.
X_test_cross, y_test_cross = load_all_data(
    "Data/Cross/test1/", logs=0, batch_size=BATCH_SIZE, step=DOWNSAMPLING_STEP
)
print(f"X_test_cross shape after preprocessing: {X_test_cross.shape}")

# One-hot encode labels to match the categorical_crossentropy training targets.
y_test_cross = to_categorical(y_test_cross, num_classes=NUM_CLASSES)

In [None]:
# Score the same model on the cross-subject test set (presumably recordings
# from subjects not seen in training); evaluate() returns [loss, accuracy].
cross_scores = lstm_model.evaluate(X_test_cross, y_test_cross, verbose=1)
test_loss_cross, test_acc_cross = cross_scores
print("Test accuracy:", test_acc_cross)