In [10]:
import os
from typing import Iterator

import h5py
import numpy as np
from scipy.signal import butter, filtfilt
from sklearn.metrics import classification_report
from tensorflow.keras import Input
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# --- Configuration ---
# Absolute path of the extracted dataset root; the relative part is resolved
# against the working directory the notebook kernel was started in.
DATA_PATH = os.path.abspath("../extracted_zip_in_here/Final Project data/")
# Train / test folders of the "Intra" split.
INTRA_TRAIN_FOLDER = os.path.join(DATA_PATH, "Intra/train")
INTRA_TEST_FOLDER = os.path.join(DATA_PATH, "Intra/test")
# Fraction of time samples kept by downsample() in the intra pipeline.
# NOTE(review): create_model() defaults to 3562 time steps -- confirm that the
# downsampled recording length actually matches that default.
DOWNSAMPLE_FACTOR = 0.7

# --- File and Data Handling ---
def get_dataset_name(filepath: str) -> str:
    """Derive the HDF5 dataset key from a recording's file name.

    The key is the file name without its directory, without its extension and
    without the final underscore-separated token, e.g.
    ``"rest_105923_1.h5"`` -> ``"rest_105923"``.

    Only the last extension is stripped, so a stem that itself contains dots is
    kept intact, and a name without any extension is handled instead of
    raising ``IndexError``.

    Args:
        filepath: File name or path of the recording.

    Returns:
        The dataset key; an empty string if the stem contains no underscore.
    """
    stem, _extension = os.path.splitext(os.path.basename(filepath))
    # Drop the trailing token and re-join whatever precedes it.
    return "_".join(stem.split("_")[:-1])

def extract_data_from_folder(folder_path: str, shuffle: bool = False) -> Iterator[tuple[str, np.ndarray]]:
    """Lazily load every recording found in ``folder_path``.

    Each directory entry is opened as an HDF5 file, the dataset whose key is
    given by ``get_dataset_name`` is read fully into memory, and its transpose
    is yielded.

    Args:
        folder_path: Directory whose entries are all opened as HDF5 files.
        shuffle: If true, the directory listing is shuffled in place with
            ``np.random.shuffle`` before iteration.

    Yields:
        ``(dataset_name, data)`` where ``data`` is the transposed dataset.
    """
    files = os.listdir(folder_path)
    if shuffle:
        np.random.shuffle(files)
    for file_name in files:
        dataset_name = get_dataset_name(file_name)
        # Read the full dataset and close the file *before* yielding, so the
        # handle is not held open while the consumer works on the data (or
        # leaked if the consumer abandons the generator early).
        with h5py.File(os.path.join(folder_path, file_name), 'r') as f:
            data = f[dataset_name][()].T  # transpose once here
        yield dataset_name, data

def learn_minmax(folder_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Compute the element-wise minimum and maximum over all recordings.

    Every recording yielded by ``extract_data_from_folder`` is reduced along
    axis 0 and the results are merged across recordings.

    Returns:
        ``(min_val, max_val)``; both are ``None`` when the folder yields no
        recordings.
    """
    lows, highs = None, None
    for _name, recording in extract_data_from_folder(folder_path):
        recording_low = np.min(recording, axis=0)
        recording_high = np.max(recording, axis=0)
        if lows is None:
            lows = recording_low
        else:
            lows = np.minimum(lows, recording_low)
        if highs is None:
            highs = recording_high
        else:
            highs = np.maximum(highs, recording_high)
    return lows, highs

def generate_label(name: str) -> np.ndarray:
    """Build the one-hot class vector for a dataset name.

    The first class whose name followed by an underscore occurs anywhere in
    ``name`` determines the position of the single 1 in the returned vector.

    Raises:
        ValueError: If no known class name occurs in ``name``.
    """
    classes = ("rest", "task_motor", "task_story_math", "task_working_memory")
    for position, class_name in enumerate(classes):
        if f"{class_name}_" not in name:
            continue
        one_hot = np.zeros(len(classes))
        one_hot[position] = 1
        return one_hot
    raise ValueError(f"Unknown file name: {name}")

# --- Preprocessing ---
def scale_data(data: np.ndarray, min_val: np.ndarray, max_val: np.ndarray) -> np.ndarray:
    """Min-max scale ``data`` using previously learned bounds.

    Computes ``(data - min_val) / (max_val - min_val)``. Where the learned
    range is exactly zero (``max_val == min_val``) the divisor is replaced by 1
    so the result is finite instead of NaN/inf, mirroring the zero-std guard in
    ``zscore_per_channel``.

    Args:
        data: Values to scale.
        min_val: Lower bound(s); must broadcast against ``data``.
        max_val: Upper bound(s); must broadcast against ``data``.

    Returns:
        The scaled array.
    """
    value_range = np.asarray(max_val - min_val)
    # Avoid 0/0 for entries whose learned min and max coincide.
    safe_range = np.where(value_range == 0, 1, value_range)
    return (data - min_val) / safe_range

def downsample(data: np.ndarray, factor: float) -> np.ndarray:
    """Resample ``data`` along axis 0 by picking evenly spaced rows.

    ``int(len(data) * factor)`` rows are selected at indices
    ``floor(i * len(data) / num_samples)``.

    Args:
        data: Array indexed along its first axis.
        factor: Fraction of rows to keep; values above 1 repeat rows.

    Returns:
        The selected rows. If no rows would be selected (empty input, or a
        factor so small that the target length rounds down to zero) an empty
        slice of ``data`` is returned instead of raising ``ZeroDivisionError``.
    """
    total = len(data)
    num_samples = int(total * factor)
    if num_samples <= 0:
        # Nothing to pick; also avoids dividing by zero below.
        return data[:0]
    indices = np.floor(np.arange(num_samples) * (total / num_samples)).astype(int)
    return data[indices]

def add_gaussian_noise(data: np.ndarray, stddev: float = 0.01) -> np.ndarray:
    """Return ``data`` plus zero-mean Gaussian noise of the same shape.

    The noise is drawn from NumPy's global random state with standard
    deviation ``stddev``.
    """
    return data + np.random.normal(loc=0, scale=stddev, size=data.shape)

def time_shift(data: np.ndarray, shift_max: int = 10) -> np.ndarray:
    """Circularly shift ``data`` along axis 0 by a random number of steps.

    The shift is drawn uniformly from the inclusive range
    ``[-shift_max, shift_max]`` using NumPy's global random state.

    Args:
        data: Array to roll along its first axis.
        shift_max: Largest absolute shift; 0 leaves the data unchanged.

    Returns:
        The rolled array (rows wrap around).
    """
    # randint's upper bound is exclusive: +1 makes the range symmetric and
    # keeps shift_max == 0 valid instead of raising "low >= high".
    shift = np.random.randint(-shift_max, shift_max + 1)
    return np.roll(data, shift, axis=0)

def channel_dropout(data: np.ndarray, dropout_rate: float = 0.1) -> np.ndarray:
    """Zero out randomly chosen columns of a 2-D array.

    One uniform random number is drawn per column (axis 1) from NumPy's global
    random state; columns whose draw is not greater than ``dropout_rate`` are
    multiplied by zero, the others are left unchanged.
    """
    keep = np.random.rand(data.shape[1]) > dropout_rate
    # Broadcast the per-column mask over all rows.
    return data * keep[None, :]

def random_scaling(data: np.ndarray, scale_range=(0.9, 1.1)) -> np.ndarray:
    """Multiply ``data`` by a single random factor.

    The factor is drawn with ``np.random.uniform`` from ``scale_range`` using
    NumPy's global random state and applied to every element.
    """
    factor = np.random.uniform(*scale_range)
    return np.multiply(data, factor)

def bandpass_filter(data: np.ndarray, lowcut=0.5, highcut=30.0, fs=250.0, order=3) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    Args:
        data: Signal to filter; filtering runs along the first axis.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency used to normalise the cut-offs.
            NOTE(review): the 250.0 default is not derived from the data --
            confirm it matches the recordings' actual sampling rate.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The forward-backward filtered signal.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return filtfilt(numerator, denominator, data, axis=0)

def zscore_per_channel(data: np.ndarray) -> np.ndarray:
    """Standardise ``data`` along axis 0 (zero mean, unit standard deviation).

    Entries whose standard deviation is exactly zero are divided by 1 instead,
    so they come out as zeros rather than NaN.
    """
    center = np.mean(data, axis=0)
    spread = np.std(data, axis=0)
    flat = spread == 0
    spread[flat] = 1  # constant entries: avoid dividing by zero
    return (data - center) / spread

def baseline_correction(data: np.ndarray, baseline_duration=500) -> np.ndarray:
    """Subtract the mean of the first ``baseline_duration`` rows from ``data``.

    The mean is taken along axis 0. If ``data`` has fewer than
    ``baseline_duration`` rows it is returned unchanged (the same object).
    """
    if data.shape[0] >= baseline_duration:
        offset = np.mean(data[:baseline_duration], axis=0)
        return data - offset
    return data

# --- Batching ---
def create_batches(folder: str, batch_size: int, preprocessing: list = None, shuffle: bool = False) -> Iterator[tuple[np.ndarray, np.ndarray]]:
    """Stream ``(samples, labels)`` batches built from the files in ``folder``.

    Every recording from ``extract_data_from_folder`` is passed through the
    ``preprocessing`` callables in order and paired with the one-hot label
    derived from its dataset name. A batch is emitted after every
    ``batch_size`` recordings; any remainder is emitted as a final, smaller
    batch.

    Args:
        folder: Directory to read recordings from.
        batch_size: Number of recordings per full batch.
        preprocessing: Optional sequence of ``array -> array`` callables.
        shuffle: Forwarded to ``extract_data_from_folder``.

    Yields:
        Two arrays: the stacked recordings and the stacked labels.
    """
    samples, labels = [], []
    recordings = extract_data_from_folder(folder, shuffle)
    for count, (name, data) in enumerate(recordings, start=1):
        for step in preprocessing or ():
            data = step(data)
        samples.append(data)
        labels.append(generate_label(name))

        if count % batch_size == 0:
            yield np.array(samples), np.array(labels)
            samples, labels = [], []
    # Flush the incomplete last batch, if any.
    if samples:
        yield np.array(samples), np.array(labels)

# --- Model ---
def create_model(input_shape=(3562, 248)) -> Sequential:
    """Build and compile the LSTM classifier.

    Architecture: a 64-unit LSTM that returns only its final output, a 64-unit
    ReLU dense layer, and a 4-way softmax output (one unit per class produced
    by ``generate_label``). The model is compiled with categorical
    cross-entropy, the Adam optimizer with default settings, and accuracy as
    the metric.

    Args:
        input_shape: Shape of one sample without the batch axis, i.e.
            ``(time_steps, features)`` for the LSTM.
            NOTE(review): the default must match what the preprocessing
            pipeline actually produces -- confirm against the downsample
            factor in use.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential([
        # Declare the input with an explicit Input layer rather than the
        # `input_shape=` layer argument, which Keras 3 deprecates for
        # Sequential models.
        Input(shape=input_shape),
        LSTM(64, return_sequences=False),
        Dense(64, activation='relu'),
        Dense(4, activation='softmax')
    ])
    model.compile(
        loss=CategoricalCrossentropy(),
        optimizer=Adam(),
        metrics=['accuracy']
    )
    return model

def train_model(model, train_folder: str, epochs: int = 10, batch_size: int = 8, preprocessing: list = None):
    """Fit ``model`` on batches streamed from ``train_folder``.

    For each epoch the folder is re-read batch by batch (files are not
    shuffled), the samples inside each batch are randomly permuted, and
    ``model.fit`` is called once per batch.

    Returns:
        The same ``model`` object after training.
    """
    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}/{epochs}")
        batches = create_batches(train_folder, batch_size, preprocessing, shuffle=False)  # TODO was True
        for batch_X, batch_y in batches:
            order = np.random.permutation(len(batch_X))
            model.fit(batch_X[order], batch_y[order], verbose=0)
    return model

def evaluate_model(model, folder: str, batch_size: int, preprocessing: list) -> tuple[float, float]:
    """Compute loss and accuracy of ``model`` over all recordings in ``folder``.

    The folder is evaluated batch by batch and the per-batch metrics are
    averaged weighted by the number of samples in each batch, so a smaller
    final batch does not get the same weight as a full one.

    Args:
        model: Compiled model whose ``evaluate`` returns ``(loss, accuracy)``.
        folder: Directory to read recordings from.
        batch_size: Number of recordings per batch.
        preprocessing: Callables applied to each recording before evaluation.

    Returns:
        ``(loss, accuracy)``; both are NaN if the folder yields no batches.
    """
    losses, accuracies, sizes = [], [], []
    for batch_X, batch_y in create_batches(folder, batch_size, preprocessing, shuffle=False):
        loss, acc = model.evaluate(batch_X, batch_y, verbose=0)
        losses.append(loss)
        accuracies.append(acc)
        sizes.append(len(batch_X))
    if not sizes:
        # np.average would raise on empty weights; report "no data" as NaN.
        return float("nan"), float("nan")
    mean_loss = float(np.average(losses, weights=sizes))
    mean_accuracy = float(np.average(accuracies, weights=sizes))
    return mean_loss, mean_accuracy

def detailed_evaluation(model, folder, batch_size, preprocessing):
    """Print a per-class classification report for the recordings in ``folder``.

    Predictions are collected batch by batch; true and predicted classes are
    the arg-max positions of the one-hot labels and of the model outputs.
    """
    true_classes, predicted_classes = [], []
    batches = create_batches(folder, batch_size, preprocessing, shuffle=False)
    for batch_X, batch_y in batches:
        probabilities = model.predict(batch_X)
        true_classes.extend(np.argmax(batch_y, axis=1))
        predicted_classes.extend(np.argmax(probabilities, axis=1))
    report = classification_report(true_classes, predicted_classes)
    print(report)


In [None]:
# --- Run Training and Evaluation ---
min_val, max_val = learn_minmax(INTRA_TRAIN_FOLDER)
print(f"Learned min/max shapes: {min_val.shape}, {max_val.shape}")

preprocessing_pipeline = [
    # baseline_correction,
    # bandpass_filter,
    lambda x: scale_data(x, min_val, max_val),
    lambda x: downsample(x, DOWNSAMPLE_FACTOR),
    add_gaussian_noise, # These are all added for more general performance ()
    # time_shift,
    channel_dropout
    # random_scaling
]

model = create_model()
trained_model = train_model(model, INTRA_TRAIN_FOLDER, epochs=10, batch_size=8, preprocessing=preprocessing_pipeline)

train_loss, train_acc = evaluate_model(trained_model, INTRA_TRAIN_FOLDER, batch_size=8, preprocessing=preprocessing_pipeline)
print(f"Training Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

test_loss, test_acc = evaluate_model(trained_model, INTRA_TEST_FOLDER, batch_size=8, preprocessing=preprocessing_pipeline)
print(f"Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")

In [None]:
detailed_evaluation(trained_model, INTRA_TEST_FOLDER, batch_size=8, preprocessing=preprocessing_pipeline)

# Cross testing

In [8]:
# Cross configuration
CROSS_TRAIN_FOLDER = os.path.join(DATA_PATH, "Cross/train")
CROSS_TEST1_FOLDER = os.path.join(DATA_PATH, "Cross/test1")
CROSS_TEST2_FOLDER = os.path.join(DATA_PATH, "Cross/test2")
CROSS_TEST3_FOLDER = os.path.join(DATA_PATH, "Cross/test3")
DOWNSAMPLE_FACTOR_CROSS = 1.0

min_val_cross, max_val_cross = learn_minmax(CROSS_TRAIN_FOLDER)

preprocessing_pipeline_cross = [
    lambda x: scale_data(x, min_val_cross, max_val_cross),
    lambda x: downsample(x, DOWNSAMPLE_FACTOR_CROSS),
    add_gaussian_noise,
    channel_dropout
]

In [None]:
# Actual training
model_cross = create_model() # TODO change this to a more 'general' model
trained_model_cross = train_model(model_cross, CROSS_TRAIN_FOLDER, epochs = 10, batch_size = 8, 
                                  preprocessing = preprocessing_pipeline_cross)

In [13]:
print('Test1 folder')
detailed_evaluation(trained_model_cross, CROSS_TEST1_FOLDER, batch_size = 8, 
                    preprocessing = preprocessing_pipeline_cross)
print('Test2 folder')
detailed_evaluation(trained_model_cross, CROSS_TEST2_FOLDER, batch_size = 8, 
                    preprocessing = preprocessing_pipeline_cross)
print('Test3 folder')
detailed_evaluation(trained_model_cross, CROSS_TEST3_FOLDER, batch_size = 8, 
                    preprocessing = preprocessing_pipeline_cross)

Test1 folder
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         4
           1       1.00      0.25      0.40         4
           2       0.44      1.00      0.62         4
           3       1.00      0.50      0.67         4

    accuracy                           0.69        16
   macro avg       0.86      0.69      0.67        16
weighted avg       0.86      0.69      0.67        16

Test2 folder
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step
              precision    recall  f1-score   support

           0       0.50      1.00      0.67         4
           1       0.75      0.75      0.75         4
           2       1.00      0.75      0.86         4
           3       0.00      0.00      0.0