In [1]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import mne
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

task_mapping = {
    3: "task1", 4: "task2", 5: "task3", 6: "task4",
    7: "task1", 8: "task2", 9: "task3", 10: "task4",
    11: "task1", 12: "task2", 13: "task3", 14: "task4"
}

In [3]:
def make_sliding_epochs_with_offset(raw, duration, overlap, offset_sec=0.0):
    raw_offset = raw.copy()
    raw_offset.crop(tmin=offset_sec, tmax=None)
    epochs = mne.make_fixed_length_epochs(
        raw_offset, duration=duration, overlap=overlap, preload=True, verbose=False
    )
    return epochs

def load_eeg_data(file_path):
    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    raw.pick(['Oz..', 'T7..', 'Cz..'])
    raw.filter(1., 40., fir_design='firwin', verbose=False)

    T_sec = 1
    stride_sec = 4 / 160
    offset_sec = 8 / 160
    overlap = T_sec - stride_sec

    epochs1 = make_sliding_epochs_with_offset(raw, duration=T_sec, overlap=overlap, offset_sec=0.0)
    epochs2 = make_sliding_epochs_with_offset(raw, duration=T_sec, overlap=overlap, offset_sec=offset_sec)

    return mne.concatenate_epochs([epochs1, epochs2])

In [None]:
root_dir = "./files/"

data_dict = {}

i = 0
for subject in sorted(os.listdir(root_dir)):
    subject_path = os.path.join(root_dir, subject)
    if os.path.isdir(subject_path) and re.match(r"S\d{3}", subject):
        
        # control the number of persons
        if(i < 5):
            i+=1
        else:
            break

        edf_files = sorted([f for f in os.listdir(subject_path) if f.endswith(".edf")])

        for edf_file in edf_files:
            match = re.match(r"(S\d{3})R(\d{2})\.edf", edf_file)
            if match:
                subject_id, session_id = match.groups()
                session_id = int(session_id)

                if session_id in task_mapping:
                    task = task_mapping[session_id]
                    full_path = os.path.join(subject_path, edf_file)

                    if subject_id not in data_dict:
                        data_dict[subject_id] = {task: []}
                    if task not in data_dict[subject_id]:
                        data_dict[subject_id][task] = []
                    
                    data_dict[subject_id][task].append(full_path)

In [14]:
train_files, test_files = [], []

for subject_id, tasks in data_dict.items():
    for task, file_list in tasks.items():
        if len(file_list) >= 3:
            train_files.extend(file_list[:2])
            test_files.append(file_list[2])

class EEGMotorImageryDataset(Dataset):
    def __init__(self, file_list):
        self.file_list = file_list
        self.data = []
        self.labels = []

        for file_path in self.file_list:
            epochs = load_eeg_data(file_path)
            self.data.append(epochs.get_data())
            self.labels.append(epochs.events[:, -1])

        self.data = np.concatenate(self.data, axis=0).astype(np.float32)
        self.labels = np.concatenate(self.labels, axis=0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        sample = sample[np.newaxis, :, :]
        return torch.tensor(sample), torch.tensor(label)

In [16]:
train_dataset = EEGMotorImageryDataset(train_files)
test_dataset = EEGMotorImageryDataset(test_files)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9920 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


Not setting metadata
9760 matching events found
No baseline correction applied


  return mne.concatenate_epochs([epochs1, epochs2])


In [19]:
class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 64), padding=(0, 32))
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 32), groups=16)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool1 = nn.AvgPool2d((1, 8))
        self.dropout1 = nn.Dropout(0.5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=(1, 16), padding=(0, 8))
        self.bn3 = nn.BatchNorm2d(64)
        self.pool2 = nn.AvgPool2d((1, 4))
        self.flatten = nn.Flatten()
        self.fc = None

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.pool2(x)
        x = self.flatten(x)

        if self.fc is None:
            feature_dim = x.shape[1]
            self.fc = nn.Linear(feature_dim, 4).to(x.device)

        x = self.fc(x)

        return x

In [21]:
model = EEGNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y).sum().item()
            total += y.size(0)
    
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {running_loss / len(train_loader)} | Test Accuracy: {100 * correct / total:.2f}%")



KeyboardInterrupt: 