In [3]:
import os
import random
import pickle
import re
from pathlib import Path

import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.signal import butter, filtfilt
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader as DL
from torch.utils.data import TensorDataset as TData
from tqdm import tqdm
import re
from sklearn.model_selection import train_test_split as tts
import pickle
import braindecode as bd


In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [5]:
def getAllPickles(directory="LHNT EEG"):
    """
    Searches through the directory and all its subfolders,
    returning a list of all .pkl file paths.
    """
    folders = [
        drctry for drctry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, drctry))
    ]
    files = []
    for folder in folders:
        folder_files = os.listdir(os.path.join(directory, folder))
        for file in folder_files:
            if ".pkl" in file:
                files.append(os.path.join(directory, folder, file))
    return files

def npFromPickle(pickle_files):
    """
    Loads NumPy arrays and labels from a list of pickle files.
    Label: 0 for 'left', 1 for 'right'.
    """
    np_data = []
    labels = []  # 0 is left, 1 is right
    for file in pickle_files:
        with open(file, "rb") as f:
            data1 = pickle.load(f)
            np_data.append(data1[0])
        # infer label from filename
        if 'right' in file.split('/')[-1]:
            labels.append(1)
        else:
            labels.append(0)
    return np_data, labels

np_data, labels = npFromPickle(getAllPickles())
print(len(np_data), len(labels))

380 380


In [6]:
def bandpass_filter(signal, crit_freq=[1, 40], sampling_freq=125, plot=False, channel=0):
    """
    Butterworth bandpass filter. 
    """
    order = 4
    b, a = scipy.signal.butter(
        order, crit_freq, btype='bandpass', fs=sampling_freq
    )
    processed_signal = scipy.signal.filtfilt(b, a, signal, axis=1)

    if plot:
        plt.figure()
        plt.xlabel('Time')
        plt.ylabel(f'Normalized amplitude of channel {channel}')
        plt.title(f'{crit_freq[0]}-{crit_freq[1]}Hz bandpass filter')

        # Plot unfiltered
        signal_min = np.min(signal, axis=1, keepdims=True)
        signal_max = np.max(signal, axis=1, keepdims=True)
        normed_signal = (signal - signal_min) / (signal_max - signal_min)

        # Plot filtered
        filtered_min = np.min(processed_signal, axis=1, keepdims=True)
        filtered_max = np.max(processed_signal, axis=1, keepdims=True)
        normed_filt = (processed_signal - filtered_min) / (filtered_max - filtered_min)

        plt.plot(normed_signal[channel], label='Input')
        plt.plot(normed_filt[channel], label='Transformed')
        plt.legend()
        plt.show()

    return processed_signal

def channel_rearrangment(sig, channel_order):
    """
    Rearranges channels according to the given channel_order list.
    NOTE: Channels in channel_order are assumed to be 1-indexed,
          so we shift by 1 to make them 0-indexed.
    """
    channel_order = [ch - 1 for ch in channel_order]
    reindexed = np.zeros_like(sig)
    for i, ind in enumerate(channel_order):
        reindexed[i] = sig[ind]
    return reindexed
  
ordered_channels = [1, 9, 11, 3, 2, 12, 10, 4, 13, 5, 15, 7, 14, 16, 6, 8]

In [7]:
train_x, test_x, train_y, test_y = tts(np_data, labels, test_size=0.25, random_state=42)


In [8]:
fixed_length = 1750

def fix_length(sig, fixed_len=875):
    """
    Ensure each EEG signal has exactly fixed_len timepoints.
    - If sig is longer than fixed_len, crop it.
    - If sig is shorter than fixed_len, zero-pad it.
    """
    c, l = sig.shape  # c = number of channels, l = number of timepoints
    if l > fixed_len:
        # Crop if it's too long
        return sig[:, :fixed_len]
    elif l < fixed_len:
        # Zero-pad if it's too short
        pad_width = fixed_len - l
        padded = np.zeros((c, fixed_len))
        padded[:, :l] = sig
        return padded
    else:
        # Exactly fixed_len
        return sig


In [9]:
train_eeg = []
train_labels = []
test_eeg = []
test_labels = []

for sig, label in zip(train_x, train_y):
    # Exclude empty signals
    if sig.shape[1] == 0:
        continue
    # 1) Channel re-index
    reindexed_signal = channel_rearrangment(sig, ordered_channels)
    # 2) Filter
    filtered_sig = bandpass_filter(reindexed_signal, [5, 40], 125)
    # 3) Standard scaling
    normed_sig = (filtered_sig - np.mean(filtered_sig, axis=1, keepdims=True)) / \
                 np.std(filtered_sig, axis=1, keepdims=True)
    if np.isnan(normed_sig).any():
        continue

    # Crop or pad to 875 timepoints
    fixed_sig = fix_length(normed_sig, fixed_len=1750)

    train_eeg.append(fixed_sig)
    train_labels.append(label)

for sig, label in zip(test_x, test_y):
    if sig.shape[1] == 0:
        continue
    reindexed_signal = channel_rearrangment(sig, ordered_channels)
    filtered_sig = bandpass_filter(reindexed_signal, [5, 40], 125)
    normed_sig = (filtered_sig - np.mean(filtered_sig, axis=1, keepdims=True)) / \
                 np.std(filtered_sig, axis=1, keepdims=True)
    if np.isnan(normed_sig).any():
        continue

    fixed_sig = fix_length(normed_sig, fixed_len=1750)

    test_eeg.append(fixed_sig)
    test_labels.append(label)

import torch

# Dimensions: (num_samples, num_channels, num_timepoints)
train_eeg_tensor = torch.zeros(
    (len(train_eeg), train_eeg[0].shape[0], train_eeg[0].shape[1])
)
test_eeg_tensor = torch.zeros(
    (len(test_eeg), test_eeg[0].shape[0], test_eeg[0].shape[1])
)

for i in range(len(train_eeg)):
    train_eeg_tensor[i] = torch.from_numpy(train_eeg[i].copy())

for i in range(len(test_eeg)):
    test_eeg_tensor[i] = torch.from_numpy(test_eeg[i].copy())

# Create one-hot label tensors
train_label_tensor = torch.tensor(train_labels, dtype=torch.long)
test_label_tensor = torch.tensor(test_labels, dtype=torch.long)


In [10]:
from braindecode.models import ATCNet, EEGNetv4
from braindecode.util import set_random_seeds


set_random_seeds(seed=42, cuda=False)

n_channels = train_eeg_tensor.shape[1]
n_times = train_eeg_tensor.shape[2]
freq = 125
input_window_sample = n_times // freq
n_outputs = len(torch.unique(train_label_tensor))
atc_model = ATCNet(
    n_chans=n_channels, 
    n_outputs=n_outputs, 
    input_window_seconds=input_window_sample,  
    sfreq=freq,
    add_log_softmax=False
).to(device)

eeg_net = EEGNetv4(
    n_chans=n_channels,
    n_outputs=n_outputs,
    n_times=n_times,
    final_conv_length='auto'
).to(device)

print(f'n_channels:' {n_channels}, 'n_times:' {n_times}, 'n_outputs:' {n_outputs}, 'input_window_sample:'{input_window_sample})




16 1750 2 14


In [None]:
# using the Brain decode trainer API
from braindecode.classifier import EEGClassifier

# Define EEGClassifier with CrossEntropyLoss
clf = EEGClassifier(
    module=atc_model,
    criterion=nn.CrossEntropyLoss,  # Works directly with raw logits
    optimizer=optim.AdamW,  # Better optimizer for EEG data
    lr=0.001,  # Learning rate
    max_epochs=50,  # Train for sufficient epochs
    batch_size=64,  # Optimal for EEG classification
    iterator_train__shuffle=True,  # Ensure dataset shuffling
    aggregate_predictions=False,  # Not using cropped mode
    device=device  # Use GPU if available
)

# Train the model
print(train_eeg_tensor.shape, train_label_tensor.shape)
clf.fit(train_eeg_tensor, train_label_tensor)

clf.module_.eval()

# Run inference on test data
accuracy = clf.score(test_eeg_tensor, test_label_tensor)
print(f"Test Accuracy: {accuracy:.4f}")



torch.Size([285, 16, 1750]) torch.Size([285])


  return F.conv2d(input, weight, bias, self.stride,


  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7163[0m       [32m0.4737[0m        [35m0.6951[0m  3.8139
      2        0.7231       [32m0.5263[0m        0.6955  3.3375
      3        [36m0.6941[0m       0.5263        0.6977  3.1446
      4        0.7013       0.5088        0.6967  3.3969
      5        0.7311       0.5088        0.6965  3.5345
      6        0.7227       0.5088        0.6955  3.4725
      7        0.7065       0.4386        0.6961  3.5602
      8        [36m0.6783[0m       0.4737        0.6986  3.3878
      9        0.7270       0.4912        0.6979  3.2681
     10        0.7240       0.5263        [35m0.6942[0m  3.3110
     11        0.7138       0.5088        [35m0.6923[0m  3.5166
     12        0.7102       0.4386        0.6923  3.4097
     13        0.7196       0.4561        0.6924  3.3606
     14        0.7218       0.4912        0.6978  3.5081
     15        0

In [2]:
torch.save(atc_model, 'atc_net')

NameError: name 'torch' is not defined

In [1]:
torch.save(atc_model, 'atc_net')

NameError: name 'torch' is not defined

In [76]:
# Using a traditional pytorch training loop [still working on this]

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.001)

best_val_loss = float("inf")
early_stopping_counter = 0
patience = 25  # Stop training if no improvement in 10 epochs
num_epochs = 100  # Maximum training epochs

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * X_batch.size(0)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

    train_loss = running_loss / total
    train_acc = correct / total

    # 🔹 Validation Loop
    model.eval()
    val_loss = 0.0
    val_correct, val_total = 0, 0

    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            val_loss += loss.item() * X_batch.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == y_batch).sum().item()
            val_total += y_batch.size(0)

    val_loss /= val_total
    val_acc = val_correct / val_total

    print(f"Epoch [{epoch+1}/{num_epochs}]: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # 🔹 Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        torch.save(model.state_dict(), "best_model.pth")  # Save best model
    else:
        early_stopping_counter += 1

    if early_stopping_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break



Epoch [1/100]: Train Loss: 0.3700, Train Acc: 0.8125, Val Loss: 0.7693, Val Acc: 0.4375
Epoch [2/100]: Train Loss: 0.2765, Train Acc: 0.9219, Val Loss: 0.7787, Val Acc: 0.5000
Epoch [3/100]: Train Loss: 0.3104, Train Acc: 0.9062, Val Loss: 0.7954, Val Acc: 0.5000
Epoch [4/100]: Train Loss: 0.2410, Train Acc: 0.9375, Val Loss: 0.8116, Val Acc: 0.5000
Epoch [5/100]: Train Loss: 0.2250, Train Acc: 0.9375, Val Loss: 0.8208, Val Acc: 0.5000
Epoch [6/100]: Train Loss: 0.2277, Train Acc: 0.9375, Val Loss: 0.8310, Val Acc: 0.5000
Epoch [7/100]: Train Loss: 0.2151, Train Acc: 0.9531, Val Loss: 0.8401, Val Acc: 0.5000
Epoch [8/100]: Train Loss: 0.2163, Train Acc: 0.9219, Val Loss: 0.8519, Val Acc: 0.5000
Epoch [9/100]: Train Loss: 0.1886, Train Acc: 0.9844, Val Loss: 0.8632, Val Acc: 0.5000
Epoch [10/100]: Train Loss: 0.1604, Train Acc: 0.9844, Val Loss: 0.8782, Val Acc: 0.5000
Epoch [11/100]: Train Loss: 0.1087, Train Acc: 1.0000, Val Loss: 0.8963, Val Acc: 0.5000
Early stopping at epoch 11


In [77]:
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

# 🔹 Final Test Accuracy
with torch.no_grad():
    X_test_tensor, y_test_tensor = X_test_tensor.to(device), y_test_tensor.to(device)
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs, 1)
    test_acc = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)

print(f"Final Test Accuracy: {test_acc:.4f}")

Final Test Accuracy: 0.4375


  model.load_state_dict(torch.load("best_model.pth"))
