In [107]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [108]:
import mne
from moabb.datasets import BNCI2014_001
import numpy as np
from ComfyNet import ComfyNet
from sklearn.model_selection import train_test_split
import torch
import random
from torch import nn
from tqdm import tqdm

In [109]:
mne.set_config("MNE_DATA","bciData")

subjects = [2,3,4,5,6,7,8,9]

dataset = BNCI2014_001().get_data(subjects=subjects)

In [163]:
def get_data(subjects, dataset):

    X_bci = []
    y_bci = []

    global bci_n_channels, bci_n_samples, bci_sfreq, n_classes_bci, bci_ch_names
    
    bci_ch_names = dataset[subjects[0]]['0train']['1'].ch_names
    bci_n_channels = len(bci_ch_names)
    bci_n_samples = dataset[subjects[0]]['0train']['1'].n_times
    bci_sfreq = 250

    for subject in dataset:
        for session in dataset[subject]:
            for trial in dataset[subject][session]:
                raw = dataset[subject][session][trial]
                events, event_id = mne.events_from_annotations(raw, verbose = False)
                epochs = mne.Epochs(raw, events, event_id, 2,6, baseline=None, preload=True, verbose = False)
                X_bci.append(epochs.pick('eeg').get_data(copy=False))
                y_bci.append(epochs.events[:, 2])
                

    X_bci = np.array(X_bci)
    y_bci = np.array(y_bci)

    X_bci_reshaped = X_bci.reshape(-1, X_bci.shape[2], X_bci.shape[3])
    y_bci_reshaped = y_bci.reshape(-1)

    # 2 = binary classification, left and right hand
    # 3 = left hand, right hand, feet
    # 4 = left hand, right hand, feet, tongue
    n_classes_bci = 2

    # Binary classification for left and right hand
    if(n_classes_bci == 2):
        mask = (y_bci_reshaped != 3) & (y_bci_reshaped != 4)
    elif(n_classes_bci == 3):
        mask = (y_bci_reshaped != 4) # add feet

    y_bci_reshaped = y_bci_reshaped[mask]
    X_bci_reshaped = X_bci_reshaped[mask]

    # Reset labels to 0 and 1 as expected by pytorch
    y_bci_reshaped = y_bci_reshaped - 1

    return X_bci_reshaped, y_bci_reshaped

X_bci_reshaped, y_bci_reshaped = get_data(subjects)

In [111]:
iir_params = dict(order=3, ftype="cheby1", rp=1, output="sos")
low_cut_hz = 4.0
high_cut_hz = 40.0

# Apply band-pass filter
X_bci_reshaped = mne.filter.filter_data(
    data=X_bci_reshaped,
    method="iir",
    iir_params=iir_params,
    sfreq=bci_sfreq,
    l_freq=low_cut_hz,
    h_freq=high_cut_hz,
    phase="forward",
    n_jobs=-1,
)

# Apply z-score normalization on X
X_mean, X_std =  np.mean(X_bci_reshaped, axis=0),  np.std(X_bci_reshaped, axis=0)
X_reshaped = (X_bci_reshaped - X_mean) / X_std

# Split data for training and validation\
X_train, X_test, y_train, y_test = train_test_split(X_reshaped, y_bci_reshaped, test_size=0.1, shuffle=True)

Setting up band-pass filter from 4 - 40 Hz

IIR filter parameters
---------------------
Chebyshev I bandpass non-linear phase (one-pass forward) causal filter:
- Filter order 6 (forward)
- Cutoffs at 4.00, 40.00 Hz: -1.00, -1.00 dB



[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done 9200 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 49568 tasks      | elapsed:    2.5s
[Parallel(n_jobs=-1)]: Done 50688 out of 50688 | elapsed:    2.5s finished


In [112]:
class DatasetWrapped(torch.utils.data.Dataset):
    def __init__(self, X, Y, permute = False):
        # Convert to torch tensors
        self.X = torch.from_numpy(X).float()
        self.Y = torch.from_numpy(Y).long()
        self.permute = permute
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return(self.X[idx], self.Y[idx])
train_dataset_bci = DatasetWrapped(X_train, y_train, permute = True)
val_dataset_bci = DatasetWrapped(X_test, y_test, permute = False)

from braindecode.augmentation import AugmentedDataLoader, SignFlip, FrequencyShift, ChannelsShuffle
sfreq = 250
freq_shift = FrequencyShift(
    probability=0.5,
    sfreq=sfreq,
    max_delta_freq=2.0,  # the frequency shifts are sampled now between -2 and 2 Hz
)
channel_shuffle = ChannelsShuffle(0.5)

transforms = [freq_shift]

train_dataloader = AugmentedDataLoader(train_dataset_bci, batch_size = 72, transforms = transforms,  shuffle = True)
test_dataloader = AugmentedDataLoader(val_dataset_bci, batch_size = 72, shuffle= False)

In [113]:
# seed = 20231229

# random.seed(seed)
# torch.manual_seed(seed)
# np.random.seed(seed)

cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = False
    # torch.cuda.manual_seed_all(seed)
else:
    print("Warning: CUDA is not available on this machine, fallback to CPU.")

# Extract number of chans and time steps from dataset
n_channels = X_train.shape[1]
input_window_samples = X_train.shape[2]

model = ComfyNet(
    n_outputs=n_classes_bci,
    n_chans=len(bci_ch_names[:22]),
    n_filters_time=16,
    filter_time_length=32,
    pool_time_length=75,
    pool_time_stride=15,
    drop_prob=0.5,
    att_depth=1,
    att_heads=2,
    att_drop_prob=0.5,
    return_features=False,
    n_times = input_window_samples,
    final_fc_length = "auto"
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    model.cuda()

Layer (type (var_name):depth-idx)                       Input Shape               Output Shape              Param #                   Kernel Shape
ComfyNet (ComfyNet)                                     [1, 22, 1001]             [1, 2]                    --                        --
├─_PatchEmbedding (patch_embedding): 1-1                [1, 1, 22, 1001]          [1, 60, 16]               --                        --
│    └─Sequential (shallownet): 2-1                     [1, 1, 22, 1001]          [1, 16, 1, 60]            --                        --
│    │    └─Conv2d (0): 3-1                             [1, 1, 22, 1001]          [1, 16, 22, 970]          528                       [1, 32]
│    │    └─Conv2d (1): 3-2                             [1, 16, 22, 970]          [1, 16, 1, 970]           5,648                     [22, 1]
│    │    └─BatchNorm2d (2): 3-3                        [1, 16, 1, 970]           [1, 16, 1, 970]           32                        --
│    │    └─ELU (3): 



In [114]:
def train (model, optimizer,  loss, train_dataloader, test_dataloader, epochs = 2, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):

    pbar = tqdm(range(epochs), colour = 'green')

    for epoch in pbar:
            
        total_train_loss = 0
        train_acc = 0
        total_val_loss = 0
        val_acc = 0
        avg_train_loss = 0
        avg_val_loss = 0
        avg_train_acc=0
        avg_val_acc = 0

        trained_samples = 0


        model.train()
        for i, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            l = loss(y_pred, y)
            l.backward()
            optimizer.step()
            total_train_loss += l.item()
            train_acc += (y_pred.argmax(1) == y).sum().item()
            trained_samples+= X.shape[0]

            avg_train_loss = total_train_loss/len(train_dataloader)
            avg_train_acc = train_acc/trained_samples
        
        validated_samples = 0

        with torch.no_grad():
            # pbar = tqdm(test_dataloader, colour = 'red')
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)
                y_pred = model(X)
                l = loss(y_pred, y)
                total_val_loss += l.item()
                validated_samples += X.shape[0]
                val_acc += (y_pred.argmax(1) == y).sum().item()
                
            avg_val_loss = total_val_loss/len(test_dataloader)
            avg_val_acc = val_acc/validated_samples
            # print(f"Validation Loss: {avg_val_loss :.3f}, Validation Accuracy: {avg_val_acc :.3f}")
        pbar.set_description(f"Epoch {epoch+1}, Loss: {avg_train_loss :.3f}, Train Accuracy: {avg_train_acc :.3f}, Validation Loss: {avg_val_loss :.3f}, Validation Accuracy: {avg_val_acc :.3f}")


    return model, avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc
        
        # wandb.log({'train_acc':avg_train_acc,"train_loss":avg_train_loss, "val_loss" : avg_val_loss, "val_acc": avg_val_acc})

def test(model,loss, test_dataloader, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    model.eval()
    test_acc = 0
    total_test_loss = 0
    test_samples = 0
    with torch.no_grad():
        pbar = tqdm(test_dataloader, colour = 'red')
        for i, (X, y) in enumerate(pbar):
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            l = loss(y_pred, y)
            total_test_loss += l.item()
            test_samples += X.shape[0]
            test_acc += (y_pred.argmax(1) == y).sum().item()
        avg_test_loss = total_test_loss/len(test_dataloader)
        avg_test_acc = test_acc/test_samples
        print(f"Test Loss: {avg_test_loss :.3f}, Test Accuracy: {avg_test_acc :.3f}")
        return avg_test_loss, avg_test_acc


In [115]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, betas = (0.9,0.999))
loss = nn.NLLLoss()
train(model, optimizer, loss, train_dataloader, test_dataloader, epochs=200)

Epoch 200, Loss: 0.280, Train Accuracy: 0.868, Validation Loss: 0.424, Validation Accuracy: 0.805: 100%|[32m██████████[0m| 200/200 [03:56<00:00,  1.18s/it]


(ComfyNet(
   (patch_embedding): _PatchEmbedding(
     (shallownet): Sequential(
       (0): Conv2d(1, 16, kernel_size=(1, 32), stride=(1, 1))
       (1): Conv2d(16, 16, kernel_size=(22, 1), stride=(1, 1))
       (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (3): ELU(alpha=1.0)
       (4): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0)
       (5): Dropout(p=0.5, inplace=False)
     )
     (projection): Sequential(
       (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
       (1): Rearrange('b d_model 1 seq -> b seq d_model')
     )
   )
   (transformer): _TransformerEncoder(
     (0): _TransformerEncoderBlock(
       (0): _ResidualAdd(
         (fn): Sequential(
           (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
           (1): _MultiHeadAttention(
             (keys): Linear(in_features=16, out_features=16, bias=True)
             (queries): Linear(in_features=16, out_features=16, bias=True)
            

In [301]:
i = 6
checkpoint = torch.load(f"models/model_{i}.pt")
model.load_state_dict(checkpoint['model_state_dict'])

subjects = [i]
dataset = BNCI2014_001().get_data(subjects=subjects)
X, y = get_data(subjects, dataset)

# Apply band-pass filter
X = mne.filter.filter_data(
    data=X,
    method="iir",
    iir_params=iir_params,
    sfreq=bci_sfreq,
    l_freq=low_cut_hz,
    h_freq=high_cut_hz,
    phase="forward",
    n_jobs=-1,
)

X_norm = (X - X_mean) / X_std
test_dataset = DatasetWrapped(X_norm, y, permute = False)
test_dataloader = AugmentedDataLoader(test_dataset, batch_size = 16, shuffle = False)
test(model, loss, test_dataloader)

Setting up band-pass filter from 4 - 40 Hz

IIR filter parameters
---------------------
Chebyshev I bandpass non-linear phase (one-pass forward) causal filter:
- Filter order 6 (forward)
- Cutoffs at 4.00, 40.00 Hz: -1.00, -1.00 dB



[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done 176 tasks      | elapsed:    0.0s
[Parallel(n_jobs=-1)]: Done 4332 tasks      | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done 6096 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done 6336 out of 6336 | elapsed:    0.3s finished
100%|[31m██████████[0m| 18/18 [00:00<00:00, 124.97it/s]

Test Loss: 3.546, Test Accuracy: 0.542





(3.5464281506008573, 0.5416666666666666)

In [302]:
X_train, X_test, y_train, y_test = train_test_split(X_norm, y, test_size=0.4, shuffle=True)
train_dataset = DatasetWrapped(X_train, y_train, permute = True)
val_dataset = DatasetWrapped(X_test, y_test, permute = False)
train_dataloader = AugmentedDataLoader(train_dataset, batch_size = 64,  shuffle = True)
val_dataloader = AugmentedDataLoader(val_dataset, batch_size = 64, shuffle = False)


In [303]:
from copy import deepcopy
model2 = deepcopy(model)

In [304]:
optimizer = torch.optim.AdamW(model2.parameters(), lr=10e-5)
loss = nn.NLLLoss()
train(model2, optimizer, loss, train_dataloader, val_dataloader, epochs=30)

Epoch 30, Loss: 0.377, Train Accuracy: 0.837, Validation Loss: 0.418, Validation Accuracy: 0.853: 100%|[32m██████████[0m| 30/30 [00:04<00:00,  7.27it/s]


(ComfyNet(
   (patch_embedding): _PatchEmbedding(
     (shallownet): Sequential(
       (0): Conv2d(1, 16, kernel_size=(1, 32), stride=(1, 1))
       (1): Conv2d(16, 16, kernel_size=(22, 1), stride=(1, 1))
       (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (3): ELU(alpha=1.0)
       (4): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0)
       (5): Dropout(p=0.5, inplace=False)
     )
     (projection): Sequential(
       (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
       (1): Rearrange('b d_model 1 seq -> b seq d_model')
     )
   )
   (transformer): _TransformerEncoder(
     (0): _TransformerEncoderBlock(
       (0): _ResidualAdd(
         (fn): Sequential(
           (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
           (1): _MultiHeadAttention(
             (keys): Linear(in_features=16, out_features=16, bias=True)
             (queries): Linear(in_features=16, out_features=16, bias=True)
            

In [305]:
test(model2, loss, val_dataloader)

100%|[31m██████████[0m| 2/2 [00:00<00:00, 99.98it/s]

Test Loss: 0.397, Test Accuracy: 0.836





(0.39678336679935455, 0.8362068965517241)

In [306]:
test(model, loss, val_dataloader)   

100%|[31m██████████[0m| 2/2 [00:00<00:00, 76.91it/s]

Test Loss: 3.422, Test Accuracy: 0.578





(3.422027826309204, 0.5775862068965517)

In [211]:
X_train.shape[2]

1001