In [25]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [43]:
import os
import h5py
import numpy as np
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers, models
import tensorflow as tf

In [44]:
def get_dataset_name(filename_with_dir):
    filename_without_dir = filename_with_dir.split('/')[-1]
    temp = filename_without_dir.split('_')[:-1]
    dataset_name = '_'.join(temp)
    return dataset_name

def get_label_from_filename(filename):
    base = os.path.basename(filename)
    if base.startswith('rest'):
        return 'rest'
    elif base.startswith('task_motor'):
        return 'motor'
    elif base.startswith('task_story_math'):
        return 'math'
    elif base.startswith('task_working_memory'):
        return 'memory'
    else:
        return 'unknown'

def load_all_data(data_folder):
    X, y = [], []
    files = sorted([os.path.join(data_folder, f) for f in os.listdir(data_folder) if f.endswith('.h5')])
    for file in files:
        with h5py.File(file, 'r') as f:
            dataset_name = get_dataset_name(file)
            matrix = f.get(dataset_name)[()]
            X.append(matrix)
            y.append(get_label_from_filename(file))
    return np.array(X), np.array(y)

data_folder = '/content/drive/MyDrive/DeepLearning Project/Final Project data/Cross/train'  # or Intra/train
X, y = load_all_data(data_folder)
print(X.shape, y.shape)
print("Labels:", np.unique(y))

(64, 248, 35624) (64,)
Labels: ['math' 'memory' 'motor' 'rest']


In [66]:
from scipy.signal import butter, lfilter

def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a

def bandpass_filter(data, lowcut=1.0, highcut=40.0, fs=1000.0, order=5):
    # data: (channels, time)
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(b, a, data, axis=-1)

def preprocess_data(X, downsample_factor=70, method='zscore', bandpass=False, log_transform=False, baseline_correct=False):
    """
    Preprocess X using downsampling and different normalization methods.
    New: bandpass filter, log-transform, baseline correction options.
    """
    X_downsampled = X[:, :, ::downsample_factor]

    # Apply bandpass filter if needed
    if bandpass:
        # Assume sampling rate is 1000Hz, change if your data is different
        X_downsampled = np.array([bandpass_filter(sample, lowcut=1, highcut=40, fs=1000, order=3) for sample in X_downsampled])

    # Log-transform if requested
    if log_transform:
        X_downsampled = np.sign(X_downsampled) * np.log1p(np.abs(X_downsampled))

    # Baseline correction (remove mean over time, per channel)
    if baseline_correct:
        X_downsampled = X_downsampled - np.mean(X_downsampled, axis=2, keepdims=True)

    # Choose normalization
    X_norm = []

    if method == 'zscore':
        for sample in X_downsampled:
            sample_norm = (sample - np.mean(sample, axis=1, keepdims=True)) / (np.std(sample, axis=1, keepdims=True) + 1e-7)
            X_norm.append(sample_norm)
        X_norm = np.array(X_norm)
    elif method == 'minmax':
        for sample in X_downsampled:
            min_ = np.min(sample, axis=1, keepdims=True)
            max_ = np.max(sample, axis=1, keepdims=True)
            sample_norm = (sample - min_) / (max_ - min_ + 1e-7)
            X_norm.append(sample_norm)
        X_norm = np.array(X_norm)
    elif method == 'global_zscore':
        mean = np.mean(X_downsampled)
        std = np.std(X_downsampled)
        X_norm = (X_downsampled - mean) / (std + 1e-7)
    elif method == 'raw':
        X_norm = X_downsampled
    else:
        raise ValueError("Unknown method: %s" % method)

    return X_norm

# EXAMPLES of usage:
# Standard z-score, with bandpass filter and baseline correction:
X_prep = preprocess_data(X, downsample_factor=70, method='zscore', bandpass=True, baseline_correct=True)
print(X_prep.shape)

# Try other combos:
# X_prep = preprocess_data(X, downsample_factor=35, method='minmax', log_transform=True)
# X_prep = preprocess_data(X, downsample_factor=70, method='raw', baseline_correct=True)

(64, 248, 509)


In [67]:
X_ready = np.transpose(X_prep, (0, 2, 1))
print(X_ready.shape)  # (samples, time, channels)

(64, 509, 248)


In [68]:
le = LabelEncoder()
y_int = le.fit_transform(y)
print(Counter(y_int))  # Optional: check class balance

Counter({np.int64(3): 16, np.int64(2): 16, np.int64(0): 16, np.int64(1): 16})


In [69]:
def sliding_window_augment(X, y, crop_len=128, step=32):
    X_aug, y_aug = [], []
    for xi, yi in zip(X, y):
        for start in range(0, X.shape[1] - crop_len + 1, step):
            crop = xi[start:start+crop_len, :]
            X_aug.append(crop)
            y_aug.append(yi)
    return np.array(X_aug), np.array(y_aug)

X_aug, y_aug = sliding_window_augment(X_ready, y_int, crop_len=128, step=32)
print(X_aug.shape, y_aug.shape)

(768, 128, 248) (768,)


In [70]:
X_train, X_val, y_train, y_val = train_test_split(
    X_aug, y_aug, test_size=0.2, stratify=y_aug, random_state=42
)
print(X_train.shape, X_val.shape)

(614, 128, 248) (154, 128, 248)


In [71]:
from tensorflow.keras.layers import Bidirectional

model = Sequential([
    Bidirectional(LSTM(32, return_sequences=True), input_shape=input_shape),
    Dropout(0.5),
    LSTM(16),
    Dropout(0.5),
    Dense(16, activation='relu'),
    Dropout(0.5),
    Dense(4, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()

  super().__init__(**kwargs)


In [72]:
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=32
)

Epoch 1/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 133ms/step - accuracy: 0.2649 - loss: 1.3862 - val_accuracy: 0.2468 - val_loss: 1.3865
Epoch 2/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 97ms/step - accuracy: 0.2764 - loss: 1.3844 - val_accuracy: 0.2468 - val_loss: 1.3863
Epoch 3/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 102ms/step - accuracy: 0.2518 - loss: 1.3868 - val_accuracy: 0.2532 - val_loss: 1.3862
Epoch 4/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 99ms/step - accuracy: 0.2465 - loss: 1.3856 - val_accuracy: 0.2468 - val_loss: 1.3863
Epoch 5/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 107ms/step - accuracy: 0.2439 - loss: 1.3879 - val_accuracy: 0.2532 - val_loss: 1.3863
Epoch 6/30
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 105ms/step - accuracy: 0.2743 - loss: 1.3851 - val_accuracy: 0.2532 - val_loss: 1.3862
Epoch 7/30
[1m20/20[0m [32m