In [1]:
import mne
import numpy as np
from scipy.signal import butter
import torch
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
from torchsummary import summary
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import scipy

# Load and view Train_Data

In [2]:
# data_path = ['A0'+str(i)+j for i in range(1,10) for j in ['T','E']]
data_path = ['A0'+str(i)+j for i in range(1,10) for j in ['T']]
raw = [mne.io.read_raw_gdf('2a/'+path+'.gdf',stim_channel="auto", verbose='ERROR',exclude=(["EOG-left", "EOG-central", "EOG-right"])) for path in data_path]
for i in range(len(raw)):
    raw[i].rename_channels({'EEG-Fz': 'Fz', 'EEG-0': 'FC3', 'EEG-1': 'FC1', 'EEG-2': 'FCz', 'EEG-3': 'FC2', 
                            'EEG-4': 'FC4','EEG-5': 'C5', 'EEG-C3': 'C3', 'EEG-6': 'C1', 'EEG-Cz': 'Cz', 
                            'EEG-7': 'C2', 'EEG-C4': 'C4', 'EEG-8': 'C6','EEG-9': 'CP3', 'EEG-10': 'CP1', 
                            'EEG-11': 'CPz', 'EEG-12': 'CP2', 'EEG-13': 'CP4','EEG-14': 'P1', 'EEG-15': 'Pz', 
                            'EEG-16': 'P2', 'EEG-Pz': 'POz'})


In [3]:
print(raw[3].get_data().shape)
events, event_id = mne.events_from_annotations(raw[3])
print(event_id)

(22, 600915)
Used Annotations descriptions: ['1023', '1072', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '1072': 2, '32766': 3, '768': 4, '769': 5, '770': 6, '771': 7, '772': 8}


In [4]:
print(raw[0].info)
print(raw[3].info)

<Info | 8 non-empty values
 bads: []
 ch_names: Fz, FC3, FC1, FCz, FC2, FC4, C5, C3, C1, Cz, C2, C4, C6, CP3, ...
 chs: 22 EEG
 custom_ref_applied: False
 highpass: 0.5 Hz
 lowpass: 100.0 Hz
 meas_date: 2005-01-17 12:00:00 UTC
 nchan: 22
 projs: []
 sfreq: 250.0 Hz
 subject_info: 4 items (dict)
>
<Info | 8 non-empty values
 bads: []
 ch_names: Fz, FC3, FC1, FCz, FC2, FC4, C5, C3, C1, Cz, C2, C4, C6, CP3, ...
 chs: 22 EEG
 custom_ref_applied: False
 highpass: 0.5 Hz
 lowpass: 100.0 Hz
 meas_date: 2004-11-08 12:00:00 UTC
 nchan: 22
 projs: []
 sfreq: 250.0 Hz
 subject_info: 4 items (dict)
>


# Train_Data preprocessing

In [5]:
for i in range(len(raw)):
    print(raw[i].get_data().shape)
    event, eventid = mne.events_from_annotations(raw[i])
    count = np.sum((event[:,2] - 7) >= 0)
    print(eventid)

(22, 672528)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '769': 7, '770': 8, '771': 9, '772': 10}
(22, 677169)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '769': 7, '770': 8, '771': 9, '772': 10}
(22, 660530)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '769': 7, '770': 8, '771': 9, '772': 10}
(22, 600915)
Used Annotations descriptions: ['1023', '1072', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '1072': 2, '32766': 3, '768': 4, '769': 5, '770': 6, '771': 7, '772': 8}
(22, 686120)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
{'1023': 1, '107

In [6]:
# specially for 4th sub
events = []
event_ids = []
for i in range(len(raw)):
    event_to_id = dict({'769':7,'770':8,'771':9,'772':10}) # Cue on left/right/foot/tongue
    if i==3:
        event_to_id = dict({'769':5,'770':6,'771':7,'772':8}) # Cue on left/right/foot/tongue
        event, _ = mne.events_from_annotations(raw[i])
        events.append(event)
        ids = np.unique(events[i][:, 2])
        event_id = {k: v for k, v in event_to_id.items() if v in ids}
        event_ids.append(event_id)
        print(event)
        print(event_id)
        raw[i].load_data()
        data = raw[i].get_data()
    else:
        event, _ = mne.events_from_annotations(raw[i])
        events.append(event)
        ids = np.unique(events[i][:, 2])
        event_id = {k: v for k, v in event_to_id.items() if v in ids}
        event_ids.append(event_id)
        print(event)
        print(event_id)
        raw[i].load_data()
        data = raw[i].get_data()
    for i_chan in range(data.shape[0]):  # go through 22 channel
        # set min to mean
        chan = data[i_chan]
        data[i_chan] = np.where(chan == np.min(chan), np.nan, chan)
        mask = np.isnan(data[i_chan])
        chan_mean = np.nanmean(data[i_chan])
        data[i_chan, mask] = chan_mean
    raw[i] = mne.io.RawArray(data, raw[i].info, verbose="ERROR")

Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
[[     0      0      5]
 [     0      0      3]
 [ 29683      0      5]
 ...
 [670550      0      6]
 [670550      0      1]
 [671050      0      7]]
{'769': 7, '770': 8, '771': 9, '772': 10}
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
[[     0      0      5]
 [     0      0      3]
 [ 31513      0      5]
 ...
 [673632      0      7]
 [675191      0      6]
 [675691      0      9]]
{'769': 7, '770': 8, '771': 9, '772': 10}
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
[[     0      0      5]
 [     0      0      3]
 [ 23720      0      5]
 ...
 [656993      0      7]
 [658552      0      6]
 [659052      0      9]]
{'769': 7, '770': 8, '771': 9, '77

## Plot Before filter

In [7]:
# for i in range(len(raw)):
#     raw[i].plot(duration=5,n_channels=22,clipping=None)

In [8]:
%matplotlib inline

In [9]:
# # # for i in range(len(raw)):
# #     # psd_plot = raw[i].compute_psd().plot(average=True)

# psd_plot = raw[0].compute_psd().plot(average=True)

## Filter

In [10]:

tmin,tmax=-0.5,4 # tmin-tmax s after the event

for i in range(len(raw)):
    # IIR filter
    iir_params = dict(order=6, ftype='butter')
    raw[i].filter(l_freq=7., h_freq=37., method='iir', iir_params=iir_params) # 7-37 / 3-50
    # raw[i].filter(7.,47.,fir_design='firwin')
    epochs = mne.Epochs(raw[i],events[i],event_ids[i],tmin,tmax,proj=True,baseline=None,preload=True)
    labelsfile = scipy.io.loadmat('2a_label/'+data_path[i]+'.mat')
    labels = labelsfile['classlabel'].reshape(288)
    epochs_data = epochs.get_data()
#     data = epochs_data[:,:,:-1]
#     BCI_IV_2a_data = np.save(data_path[i][:-5]+'.npy',data)
    np.savez('2a_pre/'+data_path[i]+'.npz', data=epochs_data[:,:,:-1], label=labels)
    print(labels)
    print(epochs_data[:,:,:-1].shape)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 37 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 7.00, 37.00 Hz: -6.02, -6.02 dB

Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 1126 original time points ...
0 bad epochs dropped
[4 3 2 1 1 2 3 4 2 3 1 1 1 4 2 2 1 1 3 1 2 4 4 3 1 4 4 2 4 4 2 1 2 3 3 3 4
 3 1 4 2 3 2 3 4 2 3 1 1 1 4 2 1 3 1 3 2 4 1 3 3 1 3 2 4 4 4 3 1 4 2 4 2 1
 3 2 1 3 3 1 3 4 4 2 1 2 4 2 4 3 2 2 2 3 4 1 2 4 1 3 3 4 1 1 3 2 4 4 4 2 1
 3 2 4 1 4 3 2 4 4 1 2 2 3 4 2 1 1 4 2 1 3 2 2 3 1 4 3 3 3 3 1 2 1 2 1 1 3
 3 2 3 4 1 4 1 1 2 4 3 2 4 3 4 3 4 2 2 4 1 2 2 2 3 4 1 4 1 3 1 4 1 3 1 2 3
 3 4 1 2 4 2 3 3 1 4 2 4 1 1 3 3 2 4 2 2 1 2 4 4 2 2 2 2 4 4 3 4 1 2 3 2 1
 4 1 4 1 1 1 1 3 3 4 2 3 3 3 4

In [11]:
labels.shape

(288,)

## Plot After filter

In [12]:
# for i in range(len(raw)):
#     raw[i].plot(duration=5,n_channels=22,clipping=None)

In [13]:
%matplotlib inline

In [14]:
# # # for i in range(len(raw)):
# #     # psd_plot = raw[i].compute_psd().plot(average=True)

# psd_plot = raw[0].compute_psd().plot(average=True)

# Load and view Val_Data

In [15]:
# data_path = ['A0'+str(i)+j for i in range(1,10) for j in ['T','E']]
data_path = ['A0'+str(i)+j for i in range(1,10) for j in ['E']]
raw = [mne.io.read_raw_gdf('2a/'+path+'.gdf',stim_channel="auto", verbose='ERROR',exclude=(["EOG-left", "EOG-central", "EOG-right"])) for path in data_path]
for i in range(len(raw)):
    raw[i].rename_channels({'EEG-Fz': 'Fz', 'EEG-0': 'FC3', 'EEG-1': 'FC1', 'EEG-2': 'FCz', 'EEG-3': 'FC2', 
                            'EEG-4': 'FC4','EEG-5': 'C5', 'EEG-C3': 'C3', 'EEG-6': 'C1', 'EEG-Cz': 'Cz', 
                            'EEG-7': 'C2', 'EEG-C4': 'C4', 'EEG-8': 'C6','EEG-9': 'CP3', 'EEG-10': 'CP1', 
                            'EEG-11': 'CPz', 'EEG-12': 'CP2', 'EEG-13': 'CP4','EEG-14': 'P1', 'EEG-15': 'Pz', 
                            'EEG-16': 'P2', 'EEG-Pz': 'POz'})
data_path

['A01E', 'A02E', 'A03E', 'A04E', 'A05E', 'A06E', 'A07E', 'A08E', 'A09E']

In [16]:
print(raw[1].get_data().shape)
events, event_id = mne.events_from_annotations(raw[1])
print(event_id)

(22, 662666)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}


In [17]:
print(raw[0].info)
events, event_id

<Info | 8 non-empty values
 bads: []
 ch_names: Fz, FC3, FC1, FCz, FC2, FC4, C5, C3, C1, Cz, C2, C4, C6, CP3, ...
 chs: 22 EEG
 custom_ref_applied: False
 highpass: 0.5 Hz
 lowpass: 100.0 Hz
 meas_date: 2005-01-19 12:00:00 UTC
 nchan: 22
 projs: []
 sfreq: 250.0 Hz
 subject_info: 4 items (dict)
>


(array([[     0,      0,      5],
        [     0,      0,      3],
        [ 22901,      0,      5],
        ...,
        [659129,      0,      7],
        [660688,      0,      6],
        [661188,      0,      7]]),
 {'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7})

# Val_Data preprocessing

In [18]:
for i in range(len(raw)):
    print(raw[i].get_data().shape)
    event, eventid = mne.events_from_annotations(raw[i])
    count = np.sum((event[:,2] - 7) >= 0)
    print(eventid)
    print(count)

(22, 687000)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}
288
(22, 662666)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}
288
(22, 648775)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}
288
(22, 660047)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}
288
(22, 679863)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072': 2, '276': 3, '277': 4, '32766': 5, '768': 6, '783': 7}
288
(22, 666373)
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
{'1023': 1, '1072

In [19]:
event_to_id = dict({'783':7}) # Cue on left/right/foot/tongue
events = []
event_ids = []
for i in range(len(raw)):
    event, _ = mne.events_from_annotations(raw[i])
    events.append(event)
    ids = np.unique(events[i][:, 2])
    event_id = {k: v for k, v in event_to_id.items() if v in ids}
    event_ids.append(event_id)
    print(event)
    print(event_id)
    raw[i].load_data()
    data = raw[i].get_data()
    for i_chan in range(data.shape[0]):  # go through 22 channel
        # set min to mean
        chan = data[i_chan]
        data[i_chan] = np.where(chan == np.min(chan), np.nan, chan)
        mask = np.isnan(data[i_chan])
        chan_mean = np.nanmean(data[i_chan])
        data[i_chan, mask] = chan_mean
    raw[i] = mne.io.RawArray(data, raw[i].info, verbose="ERROR")

Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
[[     0      0      5]
 [     0      0      3]
 [ 34291      0      5]
 ...
 [683463      0      7]
 [685022      0      6]
 [685522      0      7]]
{'783': 7}
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
[[     0      0      5]
 [     0      0      3]
 [ 22901      0      5]
 ...
 [659129      0      7]
 [660688      0      6]
 [661188      0      7]]
{'783': 7}
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
[[     0      0      5]
 [     0      0      3]
 [ 19884      0      5]
 ...
 [645238      0      7]
 [646797      0      6]
 [647297      0      7]]
{'783': 7}
Reading 0 ... 648774  =      0.000 ...  2595.096 secs...
Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']
[[    

## Filter

In [20]:
# # val_datas = np.load('2a_pre/'+data_path[1]+'.npz')
# # val_predata=torch.Tensor(val_datas['data'])
# # val_label=torch.Tensor(val_datas['label']).long()-1
# # print(val_predata.shape,val_label.shape)
# # print(label)
# # d = np.load('2a_pre/'+data_path[3]+'.npz')
# # p=torch.Tensor(d['data'])
# # l=torch.Tensor(d['label']).long()-1
# # print(p.shape,l.shape)
# # print(l)
# datas = []
# val_datas = []
# for sub in range(len(data_path)//2):
#     # Load and store training data
#     data = np.load('2a_pre/'+data_path[2*sub]+'.npz')
#     datas.append(data)
#     predata = torch.Tensor(datas[sub]['data'])
#     label = torch.Tensor(datas[sub]['label']).long() - 1
#     print(predata.shape, label.shape)
#     print(label)
    
#     # Load and store validation data
#     val_data = np.load('2a_pre/'+data_path[2*sub+1]+'.npz')
#     val_datas.append(val_data)
#     val_predata = torch.Tensor(val_datas[sub]['data'])
#     val_label = torch.Tensor(val_datas[sub]['label']).long() - 1
#     print(val_predata.shape, val_label.shape)
#     print(val_label)


In [21]:



tmin,tmax=-0.5,4 # tmin-tmax s after the event

for i in range(len(raw)):
    # # IIR filter
    iir_params = dict(order=6, ftype='butter')
    raw[i].filter(l_freq=7., h_freq=37., method='iir', iir_params=iir_params) # 7-37 / 3-50
    # raw[i].filter(7.,47.,fir_design='firwin')
    epochs = mne.Epochs(raw[i],events[i],event_ids[i],tmin,tmax,proj=True,baseline=None,preload=True)
    labelsfile = scipy.io.loadmat('2a_label/'+data_path[i]+'.mat')
    labels = labelsfile['classlabel'].reshape(288)
    epochs_data = epochs.get_data()
#     data = epochs_data[:,:,:-1]
#     BCI_IV_2a_data = np.save(data_path[i][:-5]+'.npy',data)
    np.savez('2a_pre/'+data_path[i]+'.npz', data=epochs_data[:,:,:-1], label=labels)
    print(labels)
    print(epochs_data[:,:,:-1].shape)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 7 - 37 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 24 (effective, after forward-backward)
- Cutoffs at 7.00, 37.00 Hz: -6.02, -6.02 dB

Not setting metadata
288 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 288 events and 1126 original time points ...
0 bad epochs dropped
[1 2 2 1 2 1 2 3 2 4 1 3 2 1 4 4 4 4 4 1 3 2 1 1 3 4 1 3 3 3 1 2 1 2 2 1 2
 3 2 3 3 4 3 3 4 4 4 4 4 3 2 1 1 2 3 4 2 3 1 1 1 4 2 2 1 1 3 1 2 4 4 3 1 4
 4 2 4 4 2 1 2 3 3 3 4 3 1 4 2 3 2 3 4 2 3 1 1 1 4 2 1 3 1 3 2 4 1 3 3 1 3
 2 4 4 4 3 1 4 2 4 2 1 3 2 1 3 3 1 3 4 4 2 1 2 4 2 4 3 2 2 2 3 4 1 2 4 1 3
 3 4 1 1 3 2 4 4 4 2 1 3 2 4 1 4 3 2 4 4 1 2 2 3 4 2 1 1 4 2 1 3 2 2 3 1 4
 3 3 3 3 1 2 1 2 1 1 3 3 2 3 4 1 4 1 1 2 4 3 2 4 3 4 3 4 2 2 4 1 2 2 2 3 4
 1 4 1 3 1 4 1 3 1 2 3 3 4 1 2

In [22]:
labels.shape

(288,)

# EEGNet

In [35]:
# EEG Channel Attention (ECA)
class ECA(nn.Module):
    """EEG Channel Attention Mechanism."""
    def __init__(self, num_channels):
        super(ECA, self).__init__()
        self.layer_norm = nn.LayerNorm(num_channels)
        self.fc = nn.Sequential(
            nn.Linear(num_channels, num_channels, bias=False),
            nn.Softmax(dim=-1)
        )
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # x shape: [batch_size, sequence_length, num_channels]
        batch_size, num_channels, sequence_length = x.size()
        x = x.permute(0, 2, 1)  # Change to [batch_size, sequence_length, num_channels]
        
        x_norm = self.layer_norm(x)  # Normalize across channels
        attention_weights = self.fc(x_norm)  # Compute attention weights
        x = x * attention_weights  # Apply attention weights
        
        x = x.permute(0, 2, 1)  # Revert back to [batch_size, num_channels, sequence_length]
        return self.dropout(x) + x

# Swin Transformer Block
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block."""
    def __init__(self, dim, num_heads, window_size=7, mlp_ratio=4.0, dropout=0.1):
        super(SwinTransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # x shape: [batch_size, dim, sequence_length]
        x = x.permute(0, 2, 1)  # Transpose to [batch_size, sequence_length, dim]

        # Attention block
        residual = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x)
        x = x + residual

        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual

        # Transpose back to [batch_size, dim, sequence_length]
        x = x.permute(0, 2, 1)
        return x

# EEGNet Model
class EEGNet(nn.Module):
    def __init__(self, n_classes, channels, samples, dim=64, depth=2, heads=4, dropout=0.3):
        super(EEGNet, self).__init__()
        self.channels = channels
        self.samples = samples
        self.dim = dim
        self.n_classes = n_classes

        # Feature extraction (initial projection)
        self.feature_proj = nn.Conv2d(1, dim, kernel_size=(1, 3), stride=1, padding=(0, 1))

        # EEG Channel Attention Mechanism
        self.eca = ECA(dim)

        # Swin Transformer Blocks
        self.transformer_blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, num_heads=heads, dropout=dropout) for _ in range(depth)
        ])

        # Classification Head
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(dim * samples, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # Initial projection
        x = self.feature_proj(x)  # Shape: [batch, dim, channels, time]

        # Reshape for ECA: Combine channels and time
        x = x.mean(dim=2)  # Pool across the channel dimension
        x = self.eca(x)  # Apply EEG Channel Attention

        # Swin Transformer Blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Classification
        return self.classifier(x)

In [36]:
torch.manual_seed(2003)

# Data paths and parameters
data_path = ['A0'+str(i)+j for i in range(1, 10) for j in ['T', 'E']]
epoches = 350
accuracy_train = np.zeros((len(data_path) // 2, 1))

# Training Loop
for sub in range(len(data_path) // 2):
    # Load training data
    data = np.load(f'2a_pre/{data_path[2 * sub]}.npz')
    predata = torch.Tensor(data['data'])
    label = torch.Tensor(data['label']).long() - 1

    # Load validation data
    val_data = np.load(f'2a_pre/{data_path[2 * sub + 1]}.npz')
    val_predata = torch.Tensor(val_data['data'])
    val_label = torch.Tensor(val_data['label']).long() - 1

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Model, loss, and optimizer
    dim = 64
    num_heads = 4
    net = EEGNet(n_classes=4, channels=22, samples=1125, dim=dim, depth=2, heads=num_heads, dropout=0.3).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # Reshape data
    inpt = predata.reshape(-1, 1, 22, 1125)
    val_inpt = val_predata.reshape(-1, 1, 22, 1125)
    dataset = TensorDataset(inpt, label)
    val_dataset = TensorDataset(val_inpt, val_label)

    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)

    # Training
    train_losses, val_losses = [], []
    for epoch in range(epoches):
        net.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_losses.append(train_loss / len(train_loader))

        # Validation
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
                val_outputs = net(val_inputs)
                val_loss += criterion(val_outputs, val_targets).item()
        val_losses.append(val_loss / len(val_loader))

    # Loss curves
    plt.plot(range(epoches), train_losses, label='Train Loss')
    plt.plot(range(epoches), val_losses, label='Val Loss')
    plt.legend()
    plt.show()

    # Accuracy and Confusion Matrix
    all_targets, all_predictions = [], []
    net.eval()
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = net(val_inputs)
            _, predicted = torch.max(val_outputs, 1)
            all_targets.extend(val_targets.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
    cm = confusion_matrix(all_targets, all_predictions)
    accuracy_train[sub] = np.trace(cm) / np.sum(cm)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.show()

print(f"Mean accuracy: {np.mean(accuracy_train) * 100:.2f}%")


cpu


KeyboardInterrupt: 