In [1]:
import os
import sys
import wandb
import torch
import GPUtil
import numpy as np
from EEGNet import EEGNet
from torchinfo import summary
from torch.utils.data import DataLoader
from models.estformer.ESTFormer import ESTFormer
from sklearn.metrics import classification_report, confusion_matrix

sys.path.append('../../')
from utils.coco_data_handler import COCODataHandler
from utils.epoch_data_reader import EpochDataReader

In [2]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # Force CUDA to use the GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use first GPU
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # Enable memory optimization settings for PyTorch

In [3]:
# Check if CUDA is available
try:
    gpus = GPUtil.getGPUs()
    if gpus:
        print(f"GPUtil detected {len(gpus)} GPUs:")
        for i, gpu in enumerate(gpus):
            print(f"  GPU {i}: {gpu.name} (Memory: {gpu.memoryTotal}MB)")
        
        # Set default GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(len(gpus))])
        print(f"Set CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
    else:
        print("GPUtil found no available GPUs")
except Exception as e:
    print(f"Error checking GPUs with GPUtil: {e}")

GPUtil detected 1 GPUs:
  GPU 0: NVIDIA GeForce RTX 3070 Laptop GPU (Memory: 8192.0MB)
Set CUDA_VISIBLE_DEVICES=0


In [4]:
# Check for CUDA availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Print available GPU memory
if torch.cuda.is_available():
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

Using device: cuda
Total GPU memory: 8.59 GB
Available GPU memory: 0.00 GB


In [5]:
coco_data = COCODataHandler.get_instance()
coco_data.category_index

{'accessory': 0,
 'animal': 1,
 'appliance': 2,
 'electronic': 3,
 'food': 4,
 'furniture': 5,
 'indoor': 6,
 'kitchen': 7,
 'outdoor': 8,
 'person': 9,
 'sports': 10,
 'vehicle': 11}

In [6]:
labels = list(coco_data.category_index.keys())

In [7]:
lr_count = 14 # 8, 4

In [8]:
identifier = f"cross_PO_650ms_29_{lr_count}"
identifier

'cross_PO_650ms_29_14'

In [9]:
hr_channel_names = [
    'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 
    'Iz', 'Oz', 'POz', 'Pz', 'CPz', 
    'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'
]

# Select 32 channels for the downsampled (low-resolution) set
# This selection preserves the overall spatial coverage while reducing density
if lr_count == 4:
    lr_channel_names = [
        'PO7',
        'POz', 'Oz',
        'PO8',
    ]

if lr_count == 8:
    lr_channel_names = [
        'P7', 'PO7', 'O1',
        'POz', 'Oz',
        'P8', 'PO8', 'O2'
    ]

if lr_count == 14:
    lr_channel_names = [
        'CP3', 'P3', 'P7', 'PO7', 'O1',
        'CPz', 'Pz', 'POz', 'Oz',
        'CP4', 'P4', 'P8', 'PO8', 'O2'
    ]

len(hr_channel_names), len(lr_channel_names)

(29, 14)

In [22]:
lo_res_dataset = EpochDataReader(
    channel_names=lr_channel_names
)

hi_res_dataset = EpochDataReader(
    channel_names=hr_channel_names
)

Creating new group: cross/ground-truth/CP3-CP4-CPz-O1-O2-Oz-P3-P4-P7-P8-PO7-PO8-POz-Pz/512/around_evoked/0.65/70_25_5/97
Opening raw data file s:\PolySecLabProjects\eeg-image-decoding\code\utils\..\..\data\all-joined-1\eeg\preprocessed\ground-truth\subj01_session1_eeg.fif...
    Range : 1121 ... 1777926 =      2.189 ...  3472.512 secs
Ready.
Opening raw data file s:\PolySecLabProjects\eeg-image-decoding\code\utils\..\..\data\all-joined-1\eeg\preprocessed\ground-truth\subj01_session1_eeg.fif...
    Range : 1121 ... 1777926 =      2.189 ...  3472.512 secs
Ready.
3839 events found on stim channel Status
Event IDs: [  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94

In [23]:
batch_size = 64

lo_res_loader = DataLoader(
    lo_res_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

hi_res_loader = DataLoader(
    hi_res_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

len(lo_res_loader), len(hi_res_loader)

(719, 719)

## Downsampled EEGNet Training

In [24]:
num_classes = len(labels)

# Training parameters
epochs = 100
lr_eegnet = 5e-5

sample_item = lo_res_dataset[0][0]
num_channels = sample_item.shape[0]
time_steps = sample_item.shape[1]
sfreq = lo_res_dataset.resample_freq

config = {
    "total_epochs_trained_on": epochs,
    "model_type": "low-resolution (downsampled)", # "super-resolution (upsampled)" | "high-resolution (ground-truth)",
    "time_steps_in_seconds": time_steps / sfreq,
    "model_params": {
        "model": "EEGNet",
        "num_channels": num_channels,
        "num_classes": num_classes,
        "time_steps": time_steps,
        "builtin_montage": "standard_1020",
    },
    "dataset_params": {
        "subject_session_id": lo_res_dataset.subject_session_id,
        "epoch_type": lo_res_dataset.epoch_type,
        "split": lo_res_dataset.split,
        "duration": str((lo_res_dataset.before + lo_res_dataset.after) * 1000) + 'ms' if lo_res_dataset.epoch_type == 'around_evoked' else lo_res_dataset.fixed_length_duration,
        "batch_size": batch_size,
        "random_state": lo_res_dataset.random_state
    },
    "optimizer_params": {
        "optimizer": "Adam",
        "learning_rate": lr_eegnet,
        # "weight_decay": weight_decay,
        # "betas": (beta_1, beta_2)
    }
}

In [25]:
lo_res_eegnet = EEGNet(device, num_channels, time_steps, num_classes)
summary(lo_res_eegnet)

Layer (type:depth-idx)                   Param #
EEGNet                                   --
├─Ensure4d: 1-1                          --
├─Expression: 1-2                        --
├─Conv2d: 1-3                            512
├─BatchNorm2d: 1-4                       16
├─Conv2dWithConstraint: 1-5              224
├─BatchNorm2d: 1-6                       32
├─Expression: 1-7                        --
├─AvgPool2d: 1-8                         --
├─Dropout: 1-9                           --
├─Conv2d: 1-10                           256
├─Conv2d: 1-11                           256
├─BatchNorm2d: 1-12                      32
├─Expression: 1-13                       --
├─AvgPool2d: 1-14                        --
├─Dropout: 1-15                          --
├─Flatten: 1-16                          --
├─Linear: 1-17                           82,432
├─Linear: 1-18                           6,156
Total params: 89,916
Trainable params: 89,916
Non-trainable params: 0

In [26]:
eegnet_optimizer = torch.optim.Adam(
    params=[{'params': lo_res_eegnet.parameters()}], 
    lr=lr_eegnet
)

with wandb.init(project="eeg-eegnet", config=config) as run:
    history = lo_res_eegnet.fit(lo_res_loader, 1, eegnet_optimizer, 'cpoints', f'cross_PO_650ms_{lr_count}', use_checkpoint=True)

Loading checkpoint from cpoints\eegnet_cross_PO_650ms_14_best.pt


In [15]:
lo_res_loader.dataset.set_split_type('all')

all_preds = []
all_targets = []

lo_res_eegnet = lo_res_eegnet.to('cpu')

for batch in lo_res_loader:
    epochs = batch[0]
    one_hot_encoding = batch[1]
    y_pred = lo_res_eegnet.predict(epochs)

    all_preds.append(y_pred)
    all_targets.append(one_hot_encoding)

all_preds = np.concatenate(all_preds, axis=0)  
all_targets = np.concatenate(all_targets, axis=0)

In [16]:
print(classification_report(all_targets, all_preds, target_names=labels))

              precision    recall  f1-score   support

   accessory       0.00      0.00      0.00      5599
      animal       0.00      0.00      0.00     10825
   appliance       0.00      0.00      0.00      3160
  electronic       0.00      0.00      0.00      3878
        food       0.00      0.00      0.00      6086
   furniture       0.00      0.00      0.00      9770
      indoor       0.00      0.00      0.00      5651
     kitchen       0.00      0.00      0.00      6558
     outdoor       0.00      0.00      0.00      4358
      person       0.51      0.99      0.67     23413
      sports       0.00      0.00      0.00      9479
     vehicle       0.00      0.00      0.00     10721

   micro avg       0.51      0.23      0.32     99498
   macro avg       0.04      0.08      0.06     99498
weighted avg       0.12      0.23      0.16     99498
 samples avg       0.50      0.21      0.29     99498



## Super-Resolution EEGNet Training

In [27]:
builtin_montage = 'standard_1020'
alpha_t = 0.60
alpha_s = 0.75
r_mlp = 4 # amplification factor for MLP layers
dropout_rate = 0.5
L_s = 1  # Number of spatial layers
L_t = 1  # Number of temporal layers

# Optimizer parameters
lr_est = 5e-5
weight_decay = 0.5
beta_1 = 0.9
beta_2 = 0.95

In [28]:
estformer = ESTFormer(
    device=device, 
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    builtin_montage=builtin_montage,
    time_steps=time_steps,
    alpha_t=alpha_t,
    alpha_s=alpha_s,
    r_mlp=r_mlp,
    dropout_rate=dropout_rate,
    L_s=L_s,
    L_t=L_t
)

estformer_optimizer = torch.optim.Adam(
    params=[{'params': estformer.parameters()}], 
    lr=lr_est,
    weight_decay=weight_decay,
    betas=(beta_1, beta_2)
)

estformer.fit(
    epochs=0,
    lo_res_loader=lo_res_loader,
    hi_res_loader=hi_res_loader,
    optimizer=estformer_optimizer,
    checkpoint_dir='cpoints',
    identifier=identifier,
    use_checkpoint=True
)

Loading checkpoint from cpoints\estformer_cross_PO_650ms_29_14_best.pt
Resuming training from epoch 100


{'sigma1': [], 'sigma2': [], 'train_loss': [], 'val_loss': []}

In [29]:
lo_res_dataset.set_split_type("all")
hi_res_dataset.set_split_type("all")

In [30]:
if not hi_res_dataset.has_super_res(identifier):
    for item in lo_res_dataset:
        lo_res = item[0]
        one_hot_encoding = item[1]
        super_res = estformer.predict(torch.from_numpy(lo_res).to(estformer.device))
        super_res = super_res.to('cpu')
        item = (super_res, one_hot_encoding)
        hi_res_dataset.push_super_resolution_to_dataset(item, identifier)

KeyboardInterrupt: 

In [20]:
hi_res_dataset.group_name

'cross/super-resolution/CP1-CP2-CP3-CP4-CP5-CP6-CPz-Iz-O1-O2-Oz-P1-P10-P2-P3-P4-P5-P6-P7-P8-P9-PO3-PO4-PO7-PO8-POz-Pz-TP7-TP8/512/around_evoked/0.65/70_25_5/97/cross_PO_650ms_29_14'

In [33]:
hi_res_dataset.read_from_ground_truth()
hi_res_dataset[1][0]

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 's:\PolySecLabProjects\eeg-image-decoding\code\utils\..\..\data\all-joined-1\eeg\cross-subjects.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [30]:
hi_res_dataset.read_from_super_resolution(identifier)
hi_res_dataset[2][0]

array([[ 0.00058829,  0.00054835,  0.00050487, ...,  0.00156583,
         0.00155693,  0.00155311],
       [ 0.00041326,  0.00035107,  0.00031993, ...,  0.00103917,
         0.00101362,  0.00097254],
       [-0.00020366, -0.00023284, -0.00023586, ..., -0.00044762,
        -0.00044813, -0.00047324],
       ...,
       [ 0.0065492 ,  0.00625042,  0.00635397, ...,  0.01922379,
         0.01915877,  0.0190217 ],
       [ 0.00367304,  0.0035335 ,  0.0036623 , ...,  0.00931563,
         0.00932548,  0.00934583],
       [ 0.00204899,  0.00202788,  0.00210595, ...,  0.00494561,
         0.00495264,  0.00495143]], shape=(29, 338), dtype=float32)

In [21]:
np.all(hi_res_dataset[0][0] == 0)

np.True_

In [27]:
sample_item = hi_res_dataset[0][0]
num_channels = sample_item.shape[0]
config["model_type"] = "super-resolution (upsampled)" # "low-resolution (downsampled)" | "high-resolution (ground-truth)"

In [32]:
super_res_eegnet = EEGNet(device, num_channels, time_steps, num_classes)
summary(super_res_eegnet)

Layer (type:depth-idx)                   Param #
EEGNet                                   --
├─Ensure4d: 1-1                          --
├─Expression: 1-2                        --
├─Conv2d: 1-3                            512
├─BatchNorm2d: 1-4                       16
├─Conv2dWithConstraint: 1-5              464
├─BatchNorm2d: 1-6                       32
├─Expression: 1-7                        --
├─AvgPool2d: 1-8                         --
├─Dropout: 1-9                           --
├─Conv2d: 1-10                           256
├─Conv2d: 1-11                           256
├─BatchNorm2d: 1-12                      32
├─Expression: 1-13                       --
├─AvgPool2d: 1-14                        --
├─Dropout: 1-15                          --
├─Flatten: 1-16                          --
├─Linear: 1-17                           82,432
├─Linear: 1-18                           6,156
Total params: 90,156
Trainable params: 90,156
Non-trainable params: 0

In [33]:
eegnet_optimizer = torch.optim.Adam(
    params=[{'params': super_res_eegnet.parameters()}], 
    lr=lr_eegnet
)

with wandb.init(project="eeg-eegnet", config=config) as run:
    history = super_res_eegnet.fit(hi_res_loader, 1, eegnet_optimizer, 'cpoints', identifier, use_checkpoint=True)

No checkpoint found at cpoints\eegnet_cross_PO_650ms_29_14_best.pt, starting from scratch


Epoch 1/1:   7%|▋         | 37/503 [00:28<06:02,  1.28it/s, acc=0.8164, loss=0.4642] 
Traceback (most recent call last):
  File "C:\Users\dubs2\AppData\Local\Temp\ipykernel_14076\1189358766.py", line 7, in <module>
    history = super_res_eegnet.fit(hi_res_loader, 1, eegnet_optimizer, 'cpoints', identifier, use_checkpoint=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decoding\code\models\eegnet\EEGNet.py", line 447, in fit
    train_metrics = self.train_pass(epoch)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decoding\code\models\eegnet\EEGNet.py", line 291, in train_pass
    logits = self(X)
             ^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decoding\env\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  Fil

KeyboardInterrupt: 

In [None]:
hi_res_loader.dataset.set_split_type('all')

all_preds = []
all_targets = []

hi_res_eegnet = hi_res_eegnet.to('cpu')

for batch in hi_res_loader:
    epochs = batch[0]
    one_hot_encoding = batch[1]
    y_pred = hi_res_eegnet.predict(epochs)

    all_preds.append(y_pred)
    all_targets.append(one_hot_encoding)

all_preds = np.concatenate(all_preds, axis=0)  
all_targets = np.concatenate(all_targets, axis=0)

In [None]:
print(classification_report(all_targets, all_preds, target_names=labels))

# High Res Ground Truth EEGNet Training

In [None]:
hi_res_dataset.set_read_from_write_to("ground-truth")
config["model_type"] =  "high-resolution (ground-truth)" # "low-resolution (downsampled)" | "super-resolution (upsampled)"

In [None]:
hi_res_eegnet = EEGNet(device, num_channels, time_steps, num_classes)
summary(hi_res_eegnet)

In [None]:
eegnet_optimizer = torch.optim.Adam(
    params=[{'params': hi_res_eegnet.parameters()}], 
    lr=lr_eegnet
)

with wandb.init(project="eeg-eegnet", config=config) as run:
    history = hi_res_eegnet.fit(hi_res_loader, 1, eegnet_optimizer, 'cpoints', 'cross_PO_650ms_29', use_checkpoint=True)

In [None]:
hi_res_loader.dataset.set_split_type('all')

all_preds = []
all_targets = []

hi_res_eegnet = hi_res_eegnet.to('cpu')

for batch in hi_res_loader:
    hi_res_epochs = batch[0]
    one_hot_encoding = batch[1]
    y_pred = hi_res_eegnet.predict(hi_res_epochs)

    all_preds.append(y_pred)
    all_targets.append(one_hot_encoding)

all_preds = np.concatenate(all_preds, axis=0)  
all_targets = np.concatenate(all_targets, axis=0)

In [None]:
print(classification_report(all_targets, all_preds, target_names=labels))