In [2]:
#### Sanity check on all models
from eegnet import EEGNet
from conformer import Conformer
from deepconvnet import DeepConvNet
from ctnnet import EEGTransformer as CTNNet
from EEGPT import LitEEGPTModel
from biot import BIOTClassifier
from bendr import BendrClassifier
from cbramod import CBraModClassifier
import numpy as np
import torch

# Shape sanity check for the four models trained from scratch.
# Each model is rebuilt for every (sampling rate, window duration) pair and fed
# one random batch; the output must be (BATCH_SIZE, NUM_CLASSES).
BATCH_SIZE = 16     # trials in the dummy batch
NUM_CHANNELS = 62   # EEG channels fed to every model
NUM_CLASSES = 4     # classes predicted by every model
expected_shape = (BATCH_SIZE, NUM_CLASSES)

fs=[128, 256]       # sampling rates passed to the models
data_length=[1,3]   # multiplied by the sampling rate to get the window length in samples
for sampling_rate in fs:
    for data_l in data_length:
        window_length = int(data_l*sampling_rate)

        # Kernel and pooling sizes are scaled with the sampling rate.
        eegnet = EEGNet(
            no_spatial_filters=4,
            no_channels=NUM_CHANNELS,
            no_temporal_filters=8,
            temporal_length_1=int(sampling_rate/2),
            temporal_length_2=int(sampling_rate/128)*16,
            window_length=window_length,
            num_class=NUM_CLASSES,
            drop_out_ratio=0.50,
            pooling2=int(sampling_rate/32),
            pooling3=8,
        )
        out1 = eegnet(torch.randn(BATCH_SIZE, NUM_CHANNELS, window_length))
        assert tuple(out1.shape) == expected_shape, f"eegnet returned {tuple(out1.shape)}"
        print(f"@{sampling_rate}, {data_l}s, eegnet:", out1.shape)

        eegconformer = Conformer(num_channel=NUM_CHANNELS, data_length=window_length, emb_size=40, depth=6, n_classes=NUM_CLASSES)
        out2 = eegconformer(torch.randn(BATCH_SIZE, NUM_CHANNELS, window_length))
        assert tuple(out2.shape) == expected_shape, f"eegconformer returned {tuple(out2.shape)}"
        print("eegconformer:", out2.shape)

        deepconvnet = DeepConvNet(number_channel=NUM_CHANNELS, nb_classes=NUM_CLASSES, dropout_rate=0.5, sampling_rate=sampling_rate, data_length=window_length)
        out3 = deepconvnet(torch.randn(BATCH_SIZE, NUM_CHANNELS, window_length))
        assert tuple(out3.shape) == expected_shape, f"deepconvnet returned {tuple(out3.shape)}"
        print("deepconvnet:",out3.shape)

        ctnnet = CTNNet(heads=4, emb_size=40, depth=6, number_class=NUM_CLASSES, number_channel=NUM_CHANNELS, data_length=window_length, sampling_rate=sampling_rate)
        out4 = ctnnet(torch.randn(BATCH_SIZE, NUM_CHANNELS, window_length))
        assert tuple(out4.shape) == expected_shape, f"ctnnet returned {tuple(out4.shape)}"
        print("ctnnet:",out4.shape)


@128, 1s, eegnet: torch.Size([16, 4])
eegconformer: torch.Size([16, 4])
deepconvnet: torch.Size([16, 4])
ctnnet: torch.Size([16, 4])
@128, 3s, eegnet: torch.Size([16, 4])
eegconformer: torch.Size([16, 4])
deepconvnet: torch.Size([16, 4])
ctnnet: torch.Size([16, 4])
@256, 1s, eegnet: torch.Size([16, 4])
eegconformer: torch.Size([16, 4])
deepconvnet: torch.Size([16, 4])
ctnnet: torch.Size([16, 4])
@256, 3s, eegnet: torch.Size([16, 4])
eegconformer: torch.Size([16, 4])
deepconvnet: torch.Size([16, 4])
ctnnet: torch.Size([16, 4])


In [None]:
# Build the classifiers that start from pretrained weights / checkpoints.
# NOTE(review): EEGPT is created with 19 channels and 7 classes, BIOT with 62
# input channels, BENDR and CBraMod with 61 channels, all others with 4 classes
# -- confirm these differences are intentional.
eegpt = LitEEGPTModel(
    load_path="/lustre1/project/stg_00160/eegpt/EEGPT/checkpoint/eegpt_mcae_58chs_4s_large4E.ckpt",
    chans_num=19,
    num_class=7,
    data_length=int(256*3),
)

biot = BIOTClassifier(
    input_eeg_channel=62,
    emb_size=256,
    heads=8,
    depth=4,
    n_classes=4,
    n_fft=200,
    hop_length=100,
    n_channels=18
)
# map_location="cpu" lets the checkpoint be read on a machine without a GPU
# even if it was saved from one; load_state_dict then copies the values into
# the existing parameters of biot.biot.
biot.biot.load_state_dict(
    torch.load(
        "/vsc-hard-mounts/leuven-data/343/vsc34340/BIOT-main/pretrained-models/EEG-six-datasets-18-channels.ckpt",
        map_location="cpu",
    )
)

bendr = BendrClassifier(
    num_class=4,
    num_channels=61,
    data_length=int(256*3),
    pre_trained_model_path="/lustre1/project/stg_00160/bendr/encoder.pt",
)

# CBraMod's window is given as 200*3 samples, unlike the 256*3 used above.
cbramod = CBraModClassifier(
    num_class=4,
    num_channel=61,
    data_length=int(200*3),
    pretrained_dir="/lustre1/project/stg_00160/cbramod/pretrained_weights.pth",
)

In [3]:
# Forward one random batch through each of the four from-scratch models.
# All models receive the same (16, 62, 384) input so their outputs are
# comparable; a 256-sample batch for deepconvnet alone would not match the
# window length used for the other three.
dummy_shape = (16, 62, 384)
out1 = eegnet(torch.randn(*dummy_shape))
out2 = eegconformer(torch.randn(*dummy_shape))
out3 = deepconvnet(torch.randn(*dummy_shape))
out4 = ctnnet(torch.randn(*dummy_shape))

In [4]:
# Smoke-test CBraMod with a random 4-D batch of shape (16, 61, 3, 200).
# NOTE(review): presumably (batch, channels, segments, samples per segment) -- confirm.
cbramod_batch = torch.randn(16, 61, 3, 200)
out = cbramod(cbramod_batch)
print(out.shape)

torch.Size([16, 4])


In [2]:
# Smoke-test BIOT with a random batch of shape (16, 62, 400).
biot_batch = torch.randn(16, 62, 400)
out5 = biot(biot_batch)

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


In [2]:
import torch
# Run the four from-scratch models on independent random batches of shape
# (16, 62, 384). The generator is consumed in order, so each model draws its
# own input right before its forward pass.
scratch_models = (eegnet, eegconformer, deepconvnet, ctnnet)
out1, out2, out3, out4 = (
    model(torch.randn(16, 62, 384)) for model in scratch_models
)

In [3]:
# Show the output shapes of the four forward passes on a single line.
print(*(result.shape for result in (out1, out2, out3, out4)))

torch.Size([16, 4]) torch.Size([16, 4]) torch.Size([16, 4]) torch.Size([16, 4])


In [4]:

# Smoke-test EEGPT: a random (16, 19, 768) batch plus the indices 0..18,
# passed as a 32-bit integer tensor in the second argument.
chan_idx = np.arange(19)
eegpt_batch = torch.randn(16, 19, int(256*3))
chan_idx_tensor = torch.from_numpy(chan_idx).type(torch.IntTensor)
h = eegpt(eegpt_batch, chan_idx_tensor)

In [6]:
import yaml
# Read the channel names of the downstream dataset from the dataset spec file.
with open("/vsc-hard-mounts/leuven-data/343/vsc34340/new_eeg_mae/util/dataset_specs.yaml", 'r') as f:
    dataset_yaml = yaml.safe_load(f)
downstream_channels = dataset_yaml['upper_limb_motorexecution']['chan_names']

# Reference channel list for EEGPT, one scalp row per line.
eegpt_channels = [
    'FP1', 'FPZ', 'FP2',
    'AF7', 'AF3', 'AF4', 'AF8',
    'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8',
    'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8',
    'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8',
    'O1', 'OZ', 'O2',
]

eegpt_channels_lower = [name.lower() for name in eegpt_channels]

# Lower-cased name -> first position in the EEGPT list.
eegpt_position = {}
for position, name in enumerate(eegpt_channels_lower):
    eegpt_position.setdefault(name, position)

channels_keep = []       # positions in downstream_channels that EEGPT knows
eegpt_channel_idx = []   # matching positions in eegpt_channels
for ch_idx, ch in enumerate(downstream_channels):
    ch_lower = ch.lower()
    position = eegpt_position.get(ch_lower)
    if position is None:
        # No case-insensitive exact match in the EEGPT list: drop the channel.
        continue
    eegpt_channel_idx.append(position)
    channels_keep.append(ch_idx)
print("keep ", len(channels_keep), "channels for EEGPT")

keep  31 channels for EEGPT


In [None]:
def filter_channels_from_available_channels(downstream_channels, available_channels):
    """Match downstream channel names against a reference channel list.

    Names are compared case-insensitively and must match exactly; downstream
    channels without a match are dropped.

    Args:
        downstream_channels: Sequence of channel-name strings of the dataset.
        available_channels: Sequence of channel-name strings to match against.

    Returns:
        A tuple ``(channels_keep, available_channel_idx)`` of two equally long
        lists. ``channels_keep[i]`` is the position of a matched channel in
        ``downstream_channels`` and ``available_channel_idx[i]`` is the position
        of the same channel in ``available_channels`` (first occurrence if the
        name appears more than once).
    """
    # Lower-cased name -> first position, so each lookup is O(1) while keeping
    # the "first match wins" behaviour of list.index.
    first_position = {}
    for position, name in enumerate(available_channels):
        first_position.setdefault(name.lower(), position)

    channels_keep = []
    available_channel_idx = []
    for ch_idx, ch in enumerate(downstream_channels):
        position = first_position.get(ch.lower())
        if position is not None:
            available_channel_idx.append(position)
            channels_keep.append(ch_idx)
    return channels_keep, available_channel_idx